From bdd75360f43d06e654dc93eeeaed5cafde0ef8f6 Mon Sep 17 00:00:00 2001 From: Jayant Date: Mon, 22 Jul 2024 16:50:11 +0800 Subject: [PATCH] refactor(multi-services): refactoring service and method routing for multi-services (#1439) --- internal/mocks/remote/servicesearcher.go | 64 +++++++++++ pkg/remote/codec/header_codec_test.go | 15 +-- pkg/remote/codec/util_test.go | 17 +-- pkg/remote/message.go | 65 ++++------- pkg/remote/option.go | 5 +- pkg/remote/trans/default_server_handler.go | 6 +- .../trans/default_server_handler_test.go | 22 ++-- .../trans/detection/server_handler_test.go | 20 +--- pkg/remote/trans/gonet/trans_server_test.go | 13 +-- .../trans/netpoll/http_client_handler_test.go | 15 +-- pkg/remote/trans/netpoll/trans_server_test.go | 12 +- pkg/remote/trans/netpollmux/server_handler.go | 20 +--- .../trans/netpollmux/server_handler_test.go | 20 +--- pkg/remote/trans/nphttp2/mocks_test.go | 2 +- pkg/remote/trans/nphttp2/server_handler.go | 16 +-- pkg/transmeta/ttheader.go | 25 +++- pkg/transmeta/ttheader_test.go | 3 +- server/option_advanced_test.go | 3 +- server/option_test.go | 14 +-- server/server.go | 28 ++--- server/server_test.go | 31 +++-- server/service.go | 108 +++++++----------- server/service_test.go | 53 ++------- 23 files changed, 241 insertions(+), 336 deletions(-) create mode 100644 internal/mocks/remote/servicesearcher.go diff --git a/internal/mocks/remote/servicesearcher.go b/internal/mocks/remote/servicesearcher.go new file mode 100644 index 0000000000..5b7f66e11b --- /dev/null +++ b/internal/mocks/remote/servicesearcher.go @@ -0,0 +1,64 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package remote + +import ( + "github.com/cloudwego/kitex/internal/mocks" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +type MockSvcSearcher struct { + svcMap map[string]*serviceinfo.ServiceInfo + methodSvcMap map[string]*serviceinfo.ServiceInfo +} + +func NewMockSvcSearcher(svcMap, methodSvcMap map[string]*serviceinfo.ServiceInfo) *MockSvcSearcher { + return &MockSvcSearcher{svcMap: svcMap, methodSvcMap: methodSvcMap} +} + +func NewDefaultSvcSearcher() *MockSvcSearcher { + svcInfo := mocks.ServiceInfo() + s := map[string]*serviceinfo.ServiceInfo{ + mocks.MockServiceName: svcInfo, + } + m := map[string]*serviceinfo.ServiceInfo{ + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } + return &MockSvcSearcher{svcMap: s, methodSvcMap: m} +} + +func (s *MockSvcSearcher) SearchService(svcName, methodName string, strict bool) *serviceinfo.ServiceInfo { + if strict { + if svc := s.svcMap[svcName]; svc != nil { + return svc + } + return nil + } + var svc *serviceinfo.ServiceInfo + if svcName == "" { + svc = s.methodSvcMap[methodName] + } else { + svc = s.svcMap[svcName] + } + if svc != nil { + return svc + } + return nil +} diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index 0c8066955e..af60e4c91c 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -25,6 +25,7 @@ import ( "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/kitex/internal/mocks" + mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/discovery" "github.com/cloudwego/kitex/pkg/remote" @@ -32,7 +33,6 @@ import ( "github.com/cloudwego/kitex/pkg/remote/transmeta" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" - "github.com/cloudwego/kitex/pkg/serviceinfo" tm "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/transport" ) @@ -315,17 +315,8 @@ var ( func initServerRecvMsg() remote.Message { svcInfo := mocks.ServiceInfo() - svcSearchMap := map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - } - msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, mockSvrRPCInfo, remote.Call, remote.Server, false) + svcSearcher := mocksremote.NewDefaultSvcSearcher() + msg := remote.NewMessageWithNewer(svcInfo, svcSearcher, mockSvrRPCInfo, remote.Call, remote.Server) return msg } diff --git a/pkg/remote/codec/util_test.go b/pkg/remote/codec/util_test.go index b177fe97b6..b277c8fb89 100644 --- a/pkg/remote/codec/util_test.go +++ b/pkg/remote/codec/util_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/cloudwego/kitex/internal/mocks" + mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -30,17 +31,8 @@ func TestSetOrCheckMethodName(t *testing.T) { ri := rpcinfo.NewRPCInfo(nil, rpcinfo.NewEndpointInfo("", "mock", nil, nil), rpcinfo.NewServerInvocation(), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) svcInfo := mocks.ServiceInfo() - svcSearchMap := map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - } - msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) + svcSearcher := mocksremote.NewDefaultSvcSearcher() + msg := remote.NewMessageWithNewer(svcInfo, svcSearcher, ri, remote.Call, remote.Server) err := SetOrCheckMethodName("mock", msg) test.Assert(t, err == nil) ri = msg.RPCInfo() @@ -49,7 +41,8 @@ func TestSetOrCheckMethodName(t *testing.T) { test.Assert(t, ri.Invocation().MethodName() == "mock") test.Assert(t, ri.To().Method() == "mock") - msg = remote.NewMessageWithNewer(svcInfo, map[string]*serviceinfo.ServiceInfo{}, ri, remote.Call, remote.Server, false) + m := map[string]*serviceinfo.ServiceInfo{} + msg = remote.NewMessageWithNewer(svcInfo, mocksremote.NewMockSvcSearcher(m, m), ri, remote.Call, remote.Server) err = SetOrCheckMethodName("dummy", msg) test.Assert(t, err != nil) test.Assert(t, err.Error() == "unknown method dummy") diff --git a/pkg/remote/message.go b/pkg/remote/message.go index 23c8d83a40..ec33de8030 100644 --- a/pkg/remote/message.go +++ b/pkg/remote/message.go @@ -22,7 +22,6 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/kitex/transport" ) @@ -81,6 +80,12 @@ func NewProtocolInfo(tp transport.Protocol, ct serviceinfo.PayloadCodec) Protoco } } +// ServiceSearcher is used to search the service info by service name and method name, +// strict equals to true means the service name must match the registered service name. +type ServiceSearcher interface { + SearchService(svcName, methodName string, strict bool) *serviceinfo.ServiceInfo +} + // Message is the core abstraction for Kitex message. type Message interface { RPCInfo() rpcinfo.RPCInfo @@ -115,15 +120,14 @@ func NewMessage(data interface{}, svcInfo *serviceinfo.ServiceInfo, ri rpcinfo.R } // NewMessageWithNewer creates a new Message and set data later. -func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearchMap map[string]*serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole, refuseTrafficWithoutServiceName bool) Message { +func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearcher ServiceSearcher, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole) Message { msg := messagePool.Get().(*message) msg.rpcInfo = ri msg.targetSvcInfo = targetSvcInfo - msg.svcSearchMap = svcSearchMap + msg.svcSearcher = svcSearcher msg.msgType = msgType msg.rpcRole = rpcRole msg.transInfo = transInfoPool.Get().(*transInfo) - msg.refuseTrafficWithoutServiceName = refuseTrafficWithoutServiceName return msg } @@ -139,19 +143,18 @@ func newMessage() interface{} { } type message struct { - msgType MessageType - data interface{} - rpcInfo rpcinfo.RPCInfo - targetSvcInfo *serviceinfo.ServiceInfo - svcSearchMap map[string]*serviceinfo.ServiceInfo - rpcRole RPCRole - compressType CompressType - payloadSize int - transInfo TransInfo - tags map[string]interface{} - protocol ProtocolInfo - payloadCodec PayloadCodec - refuseTrafficWithoutServiceName bool + msgType MessageType + data interface{} + rpcInfo rpcinfo.RPCInfo + targetSvcInfo *serviceinfo.ServiceInfo + svcSearcher ServiceSearcher + rpcRole RPCRole + compressType CompressType + payloadSize int + transInfo TransInfo + tags map[string]interface{} + protocol ProtocolInfo + payloadCodec PayloadCodec } func (m *message) zero() { @@ -190,19 +193,7 @@ func (m *message) SpecifyServiceInfo(svcName, methodName string) (*serviceinfo.S } return m.targetSvcInfo, nil } - if svcName == "" && m.refuseTrafficWithoutServiceName { - return nil, NewTransErrorWithMsg(NoServiceName, "no service name while the server has WithRefuseTrafficWithoutServiceName option enabled") - } - var key string - // when client does not pass svcName or passes a special service name, fallback to searching for svcInfo by method name alone - // note: This special name fallback logic shouldn't ideally be here, but it's needed to keep compatibility with older versions (<= v0.10.1) - // due to a mistake in the first release of the multi-service feature - if svcName == "" || isSpecialName(svcName) { - key = methodName - } else { - key = BuildMultiServiceKey(svcName, methodName) - } - svcInfo := m.svcSearchMap[key] + svcInfo := m.svcSearcher.SearchService(svcName, methodName, false) if svcInfo == nil { return nil, NewTransErrorWithMsg(UnknownService, fmt.Sprintf("unknown service %s, method %s", svcName, methodName)) } @@ -290,10 +281,6 @@ func (m *message) Recycle() { messagePool.Put(m) } -func isSpecialName(svcName string) bool { - return svcName == serviceinfo.GenericService || svcName == serviceinfo.CombineService || svcName == serviceinfo.CombineService_ -} - // TransInfo contains transport information. type TransInfo interface { TransStrInfo() map[string]string @@ -373,13 +360,3 @@ func FillSendMsgFromRecvMsg(recvMsg, sendMsg Message) { sendMsg.SetProtocolInfo(recvMsg.ProtocolInfo()) sendMsg.SetPayloadCodec(recvMsg.PayloadCodec()) } - -// BuildMultiServiceKey is used to create a key to search svcInfo from svcSearchMap. -func BuildMultiServiceKey(serviceName, methodName string) string { - var builder utils.StringBuilder - builder.Grow(len(serviceName) + len(methodName) + 1) - builder.WriteString(serviceName) - builder.WriteString(".") - builder.WriteString(methodName) - return builder.String() -} diff --git a/pkg/remote/option.go b/pkg/remote/option.go index afe0f567b9..71e81c2500 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -72,7 +72,7 @@ func (o *Option) AppendBoundHandler(h BoundHandler) { type ServerOption struct { TargetSvcInfo *serviceinfo.ServiceInfo - SvcSearchMap map[string]*serviceinfo.ServiceInfo + SvcSearcher ServiceSearcher TransServerFactory TransServerFactory @@ -113,9 +113,6 @@ type ServerOption struct { GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error - // RefuseTrafficWithoutServiceName is used for a server with multi services - RefuseTrafficWithoutServiceName bool - Option // invoking chain with recv/send middlewares for streaming APIs diff --git a/pkg/remote/trans/default_server_handler.go b/pkg/remote/trans/default_server_handler.go index 39c3abbd8d..6643c035ec 100644 --- a/pkg/remote/trans/default_server_handler.go +++ b/pkg/remote/trans/default_server_handler.go @@ -37,7 +37,7 @@ func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote. svrHdlr := &svrTransHandler{ opt: opt, codec: opt.Codec, - svcSearchMap: opt.SvcSearchMap, + svcSearcher: opt.SvcSearcher, targetSvcInfo: opt.TargetSvcInfo, ext: ext, } @@ -50,7 +50,7 @@ func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote. type svrTransHandler struct { opt *remote.ServerOption - svcSearchMap map[string]*serviceinfo.ServiceInfo + svcSearcher remote.ServiceSearcher targetSvcInfo *serviceinfo.ServiceInfo inkHdlFunc endpoint.Endpoint codec remote.Codec @@ -167,7 +167,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) }() ctx = t.startTracer(ctx, ri) ctx = t.startProfiler(ctx) - recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, ri, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName) + recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearcher, ri, remote.Call, remote.Server) recvMsg.SetPayloadCodec(t.opt.PayloadCodec) ctx, err = t.transPipe.Read(ctx, conn, recvMsg) if err != nil { diff --git a/pkg/remote/trans/default_server_handler_test.go b/pkg/remote/trans/default_server_handler_test.go index c8ed141988..af750ddae3 100644 --- a/pkg/remote/trans/default_server_handler_test.go +++ b/pkg/remote/trans/default_server_handler_test.go @@ -25,6 +25,7 @@ import ( "github.com/golang/mock/gomock" "github.com/cloudwego/kitex/internal/mocks" + remotemocks "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/mocks/stats" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" @@ -34,17 +35,8 @@ import ( ) var ( - svcInfo = mocks.ServiceInfo() - svcSearchMap = map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - } + svcInfo = mocks.ServiceInfo() + svcSearcher = remotemocks.NewDefaultSvcSearcher() ) func TestDefaultSvrTransHandler(t *testing.T) { @@ -72,7 +64,7 @@ func TestDefaultSvrTransHandler(t *testing.T) { return nil }, }, - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, } @@ -139,7 +131,7 @@ func TestSvrTransHandlerBizError(t *testing.T) { return nil }, }, - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, TracerCtl: tracerCtl, InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { @@ -198,7 +190,7 @@ func TestSvrTransHandlerReadErr(t *testing.T) { return mockErr }, }, - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, TracerCtl: tracerCtl, InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { @@ -256,7 +248,7 @@ func TestSvrTransHandlerOnReadHeartbeat(t *testing.T) { return nil }, }, - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, TracerCtl: tracerCtl, InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { diff --git a/pkg/remote/trans/detection/server_handler_test.go b/pkg/remote/trans/detection/server_handler_test.go index 04343edc18..5f4abfc3a8 100644 --- a/pkg/remote/trans/detection/server_handler_test.go +++ b/pkg/remote/trans/detection/server_handler_test.go @@ -36,7 +36,6 @@ import ( "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/utils" ) @@ -49,22 +48,13 @@ var ( } return grpc.ClientPrefaceLen }() - svcInfo = mocks.ServiceInfo() - svcSearchMap = map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - } + svcInfo = mocks.ServiceInfo() + svcSearcher = remote_mocks.NewDefaultSvcSearcher() ) func TestServerHandlerCall(t *testing.T) { transHdler, _ := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, }) @@ -134,7 +124,7 @@ func TestOnError(t *testing.T) { ctrl.Finish() }() transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, }) test.Assert(t, err == nil) @@ -164,7 +154,7 @@ func TestOnError(t *testing.T) { // TestOnInactive covers onInactive() codes to check panic func TestOnInactive(t *testing.T) { transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, }) test.Assert(t, err == nil) diff --git a/pkg/remote/trans/gonet/trans_server_test.go b/pkg/remote/trans/gonet/trans_server_test.go index fda743e03f..9c8086c5f0 100644 --- a/pkg/remote/trans/gonet/trans_server_test.go +++ b/pkg/remote/trans/gonet/trans_server_test.go @@ -24,10 +24,10 @@ import ( "time" "github.com/cloudwego/kitex/internal/mocks" + mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/utils" ) @@ -62,16 +62,7 @@ func TestMain(m *testing.M) { return nil }, }, - SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - }, + SvcSearcher: mocksremote.NewDefaultSvcSearcher(), TargetSvcInfo: svcInfo, TracerCtl: &rpcinfo.TraceController{}, } diff --git a/pkg/remote/trans/netpoll/http_client_handler_test.go b/pkg/remote/trans/netpoll/http_client_handler_test.go index fed8d05a30..403df6a17e 100644 --- a/pkg/remote/trans/netpoll/http_client_handler_test.go +++ b/pkg/remote/trans/netpoll/http_client_handler_test.go @@ -26,11 +26,11 @@ import ( "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/internal/mocks" + remotemocks "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans" "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/serviceinfo" ) var ( @@ -119,19 +119,10 @@ func TestHTTPRead(t *testing.T) { func TestHTTPOnMessage(t *testing.T) { // 1. prepare mock data svcInfo := mocks.ServiceInfo() - svcSearchMap := map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - } + svcSearcher := remotemocks.NewDefaultSvcSearcher() ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, method), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearcher, ri, remote.Call, remote.Server) sendMsg := remote.NewMessage(svcInfo.MethodInfo(method).NewResult(), svcInfo, ri, remote.Reply, remote.Server) // 2. test diff --git a/pkg/remote/trans/netpoll/trans_server_test.go b/pkg/remote/trans/netpoll/trans_server_test.go index f8c9ef1502..24afb15691 100644 --- a/pkg/remote/trans/netpoll/trans_server_test.go +++ b/pkg/remote/trans/netpoll/trans_server_test.go @@ -32,7 +32,6 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/utils" ) @@ -67,16 +66,7 @@ func TestMain(m *testing.M) { return nil }, }, - SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - }, + SvcSearcher: mocksremote.NewDefaultSvcSearcher(), TargetSvcInfo: svcInfo, TracerCtl: &rpcinfo.TraceController{}, } diff --git a/pkg/remote/trans/netpollmux/server_handler.go b/pkg/remote/trans/netpollmux/server_handler.go index eb3dbc3d06..279c76c45e 100644 --- a/pkg/remote/trans/netpollmux/server_handler.go +++ b/pkg/remote/trans/netpollmux/server_handler.go @@ -65,7 +65,7 @@ func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { svrHdlr := &svrTransHandler{ opt: opt, codec: opt.Codec, - svcSearchMap: opt.SvcSearchMap, + svcSearcher: opt.SvcSearcher, targetSvcInfo: opt.TargetSvcInfo, ext: np.NewNetpollConnExtension(), } @@ -84,7 +84,7 @@ var _ remote.ServerTransHandler = &svrTransHandler{} type svrTransHandler struct { opt *remote.ServerOption - svcSearchMap map[string]*serviceinfo.ServiceInfo + svcSearcher remote.ServiceSearcher targetSvcInfo *serviceinfo.ServiceInfo inkHdlFunc endpoint.Endpoint codec remote.Codec @@ -239,7 +239,7 @@ func (t *svrTransHandler) task(muxSvrConnCtx context.Context, conn net.Conn, rea }() // read - recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, rpcInfo, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName) + recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearcher, rpcInfo, remote.Call, remote.Server) bufReader := np.NewReaderByteBuffer(reader) err = t.readWithByteBuffer(ctx, bufReader, recvMsg) if err != nil { @@ -328,8 +328,7 @@ func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error { iv.SetSeqID(0) ri := rpcinfo.NewRPCInfo(nil, nil, iv, nil, nil) data := NewControlFrame() - svcInfo := t.getSvcInfo() - msg := remote.NewMessage(data, svcInfo, ri, remote.Reply, remote.Server) + msg := remote.NewMessage(data, nil, ri, remote.Reply, remote.Server) msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift)) msg.TransInfo().TransStrInfo()[transmeta.HeaderConnectionReadyToReset] = "1" @@ -478,17 +477,6 @@ func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, rpcStats.SetLevel(sl) } -// getSvcInfo is used to get one ServiceInfo -func (t *svrTransHandler) getSvcInfo() *serviceinfo.ServiceInfo { - if t.targetSvcInfo != nil { - return t.targetSvcInfo - } - for _, svcInfo := range t.svcSearchMap { - return svcInfo - } - return nil -} - func getRemoteInfo(ri rpcinfo.RPCInfo, conn net.Conn) (string, net.Addr) { rAddr := conn.RemoteAddr() if ri == nil { diff --git a/pkg/remote/trans/netpollmux/server_handler_test.go b/pkg/remote/trans/netpollmux/server_handler_test.go index 812b1747ad..163a25bf51 100644 --- a/pkg/remote/trans/netpollmux/server_handler_test.go +++ b/pkg/remote/trans/netpollmux/server_handler_test.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/internal/mocks" + mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" @@ -43,17 +44,8 @@ var ( addr = utils.NewNetAddr("tcp", addrStr) method = "mock" - svcInfo = mocks.ServiceInfo() - svcSearchMap = map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - } + svcInfo = mocks.ServiceInfo() + svcSearcher = mocksremote.NewDefaultSvcSearcher() ) func newTestRpcInfo() rpcinfo.RPCInfo { @@ -91,7 +83,7 @@ func init() { return err }, }, - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, TracerCtl: &rpcinfo.TraceController{}, ReadWriteTimeout: rwTimeout, @@ -488,7 +480,7 @@ func TestInvokeError(t *testing.T) { return err }, }, - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, TracerCtl: &rpcinfo.TraceController{}, ReadWriteTimeout: rwTimeout, @@ -721,7 +713,7 @@ func TestMuxSvrOnReadHeartbeat(t *testing.T) { return err }, }, - SvcSearchMap: svcSearchMap, + SvcSearcher: svcSearcher, TargetSvcInfo: svcInfo, TracerCtl: &rpcinfo.TraceController{}, ReadWriteTimeout: rwTimeout, diff --git a/pkg/remote/trans/nphttp2/mocks_test.go b/pkg/remote/trans/nphttp2/mocks_test.go index e027235869..f77195dcd2 100644 --- a/pkg/remote/trans/nphttp2/mocks_test.go +++ b/pkg/remote/trans/nphttp2/mocks_test.go @@ -296,7 +296,7 @@ func newMockConnOption() remote.ConnOption { func newMockServerOption() *remote.ServerOption { return &remote.ServerOption{ - SvcSearchMap: nil, + SvcSearcher: nil, TransServerFactory: nil, SvrHandlerFactory: nil, Codec: nil, diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index dd99c38915..c3b4d3b6b6 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -59,19 +59,19 @@ func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remo func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { return &svrTransHandler{ - opt: opt, - svcSearchMap: opt.SvcSearchMap, - codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), + opt: opt, + svcSearcher: opt.SvcSearcher, + codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), }, nil } var _ remote.ServerTransHandler = &svrTransHandler{} type svrTransHandler struct { - opt *remote.ServerOption - svcSearchMap map[string]*serviceinfo.ServiceInfo - inkHdlFunc endpoint.Endpoint - codec remote.Codec + opt *remote.ServerOption + svcSearcher remote.ServiceSearcher + inkHdlFunc endpoint.Endpoint + codec remote.Codec } var prefaceReadAtMost = func() int { @@ -193,7 +193,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { // set send grpc compressor at server to encode reply pack remote.SetSendCompressor(ri, s.SendCompress()) - svcInfo := t.svcSearchMap[remote.BuildMultiServiceKey(serviceName, methodName)] + svcInfo := t.svcSearcher.SearchService(serviceName, methodName, true) var methodInfo serviceinfo.MethodInfo if svcInfo != nil { methodInfo = svcInfo.MethodInfo(methodName) diff --git a/pkg/transmeta/ttheader.go b/pkg/transmeta/ttheader.go index c8ce9e1c24..8c1bb81325 100644 --- a/pkg/transmeta/ttheader.go +++ b/pkg/transmeta/ttheader.go @@ -79,16 +79,31 @@ func (ch *clientTTHeaderHandler) WriteMeta(ctx context.Context, msg remote.Messa hd[transmeta.ConnectTimeout] = strconv.Itoa(int(ri.Config().ConnectTimeout().Milliseconds())) } transInfo.PutTransIntInfo(hd) + if idlSvcName := getIDLSvcName(msg.ServiceInfo(), ri); idlSvcName != "" { + if strInfo := transInfo.TransStrInfo(); strInfo != nil { + strInfo[transmeta.HeaderIDLServiceName] = idlSvcName + } else { + transInfo.PutTransStrInfo(map[string]string{transmeta.HeaderIDLServiceName: idlSvcName}) + } + } + return ctx, nil +} + +func getIDLSvcName(svcInfo *serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo) string { idlSvcName := ri.Invocation().ServiceName() + // for combine service, idlSvcName may not be the same as server's service name var isCombineService bool - val, exists := msg.ServiceInfo().Extra["combine_service"] - if exists { - isCombineService, _ = val.(bool) + if svcInfo != nil { + val, exists := svcInfo.Extra["combine_service"] + if exists { + isCombineService, _ = val.(bool) + } } + // generic service name shouldn't be written to header if idlSvcName != serviceinfo.GenericService && !isCombineService { - transInfo.PutTransStrInfo(map[string]string{transmeta.HeaderIDLServiceName: idlSvcName}) + return idlSvcName } - return ctx, nil + return "" } // ReadMeta of clientTTHeaderHandler reads headers of TTHeader protocol from transport diff --git a/pkg/transmeta/ttheader_test.go b/pkg/transmeta/ttheader_test.go index 1bf0453534..54608d18f4 100644 --- a/pkg/transmeta/ttheader_test.go +++ b/pkg/transmeta/ttheader_test.go @@ -60,6 +60,7 @@ func TestTTHeaderClientWriteMetainfo(t *testing.T) { toInfo := rpcinfo.NewEndpointInfo("toServiceName", "toMethod", nil, nil) ri := rpcinfo.NewRPCInfo(fromInfo, toInfo, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats()) msg := remote.NewMessage(nil, mocks.ServiceInfo(), ri, remote.Call, remote.Client) + ri.Invocation().(rpcinfo.InvocationSetter).SetServiceName(msg.ServiceInfo().ServiceName) // pure paylod, no effect msg.SetProtocolInfo(remote.NewProtocolInfo(transport.PurePayload, serviceinfo.Thrift)) @@ -86,7 +87,7 @@ func TestTTHeaderClientWriteMetainfo(t *testing.T) { test.Assert(t, kvs[transmeta.ConnectTimeout] == "1000") strKvs = msg.TransInfo().TransStrInfo() test.Assert(t, len(strKvs) == 1) - test.Assert(t, strKvs[transmeta.HeaderIDLServiceName] == "") + test.Assert(t, strKvs[transmeta.HeaderIDLServiceName] == msg.ServiceInfo().ServiceName) } func TestTTHeaderServerReadMetainfo(t *testing.T) { diff --git a/server/option_advanced_test.go b/server/option_advanced_test.go index 565e21352a..75ae1f64c1 100644 --- a/server/option_advanced_test.go +++ b/server/option_advanced_test.go @@ -245,7 +245,8 @@ func TestWithSupportedTransportsFunc(t *testing.T) { svcInfo := mocks.ServiceInfo() svr.RegisterService(svcInfo, new(mockImpl)) svr.(*server).fillMoreServiceInfo(nil) - test.Assert(t, reflect.DeepEqual(svr.GetServiceInfos()[remote.BuildMultiServiceKey(svcInfo.ServiceName, mocks.MockMethod)].Extra["transports"], tcase.wantTransports)) + svcInfo = svr.(*server).svcs.SearchService(svcInfo.ServiceName, mocks.MockMethod, false) + test.Assert(t, reflect.DeepEqual(svcInfo.Extra["transports"], tcase.wantTransports)) } } diff --git a/server/option_test.go b/server/option_test.go index 1a9609b32e..e0d5302d10 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -531,17 +531,9 @@ func TestWithProfilerMessageTagging(t *testing.T) { ri := rpcinfo.NewRPCInfo(from, to, nil, nil, nil) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) svcInfo := mocks.ServiceInfo() - svcSearchMap := map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - } - msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) + svcSearcher := newServices() + svcSearcher.addService(svcInfo, mocks.MyServiceHandler(), &RegisterOptions{}) + msg := remote.NewMessageWithNewer(svcInfo, svcSearcher, ri, remote.Call, remote.Server) newCtx, tags := iSvr.opt.RemoteOpt.ProfilerMessageTagging(ctx, msg) test.Assert(t, len(tags) == 8) diff --git a/server/server.go b/server/server.go index 07027ede9e..77df8d038d 100644 --- a/server/server.go +++ b/server/server.go @@ -200,7 +200,7 @@ func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler inter } func (s *server) GetServiceInfos() map[string]*serviceinfo.ServiceInfo { - return s.svcs.getSvcInfoSearchMap() + return s.svcs.getSvcInfoMap() } // Run runs the server. @@ -346,8 +346,7 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint { func (s *server) initBasicRemoteOption() { remoteOpt := s.opt.RemoteOpt remoteOpt.TargetSvcInfo = s.targetSvcInfo - remoteOpt.SvcSearchMap = s.svcs.getSvcInfoSearchMap() - remoteOpt.RefuseTrafficWithoutServiceName = s.opt.RefuseTrafficWithoutServiceName + remoteOpt.SvcSearcher = s.svcs remoteOpt.InitOrResetRPCInfoFunc = s.initOrResetRPCInfoFunc() remoteOpt.TracerCtl = s.opt.TracerCtl remoteOpt.ReadWriteTimeout = s.opt.Configs.ReadWriteTimeout() @@ -435,7 +434,16 @@ func (s *server) check() error { if len(s.svcs.svcMap) == 0 { return errors.New("run: no service. Use RegisterService to set one") } - return checkFallbackServiceForConflictingMethods(s.svcs.conflictingMethodHasFallbackSvcMap, s.opt.RefuseTrafficWithoutServiceName) + if s.opt.RefuseTrafficWithoutServiceName { + s.svcs.refuseTrafficWithoutServiceName = true + return nil + } + for name, conflict := range s.svcs.conflictingMethodMap { + if conflict { + return fmt.Errorf("method name [%s] is conflicted between services but no fallback service is specified", name) + } + } + return nil } func doAddBoundHandlerToHead(h remote.BoundHandler, opt *remote.ServerOption) { @@ -566,15 +574,3 @@ func getDefaultSvcInfo(svcs *services) *serviceinfo.ServiceInfo { } return nil } - -func checkFallbackServiceForConflictingMethods(conflictingMethodHasFallbackSvcMap map[string]bool, refuseTrafficWithoutServiceName bool) error { - if refuseTrafficWithoutServiceName { - return nil - } - for name, hasFallbackSvc := range conflictingMethodHasFallbackSvcMap { - if !hasFallbackSvc { - return fmt.Errorf("method name [%s] is conflicted between services but no fallback service is specified", name) - } - } - return nil -} diff --git a/server/server_test.go b/server/server_test.go index 3013aa7a0e..e9b71f5e94 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -53,19 +53,7 @@ import ( "github.com/cloudwego/kitex/transport" ) -var ( - svcInfo = mocks.ServiceInfo() - svcSearchMap = map[string]*serviceinfo.ServiceInfo{ - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, - remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, - mocks.MockMethod: svcInfo, - mocks.MockExceptionMethod: svcInfo, - mocks.MockErrorMethod: svcInfo, - mocks.MockOnewayMethod: svcInfo, - } -) +var svcInfo = mocks.ServiceInfo() func TestServerRun(t *testing.T) { var opts []Option @@ -529,8 +517,9 @@ func TestGRPCServerMultipleServices(t *testing.T) { test.Assert(t, err == nil) err = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler()) test.Assert(t, err == nil) - test.DeepEqual(t, svr.GetServiceInfos()[mocks.MockMethod], mocks.ServiceInfo()) - test.DeepEqual(t, svr.GetServiceInfos()[mocks.Mock2Method], mocks.Service2Info()) + + test.DeepEqual(t, svr.(*server).svcs.SearchService("", mocks.MockMethod, false), mocks.ServiceInfo()) + test.DeepEqual(t, svr.(*server).svcs.SearchService("", mocks.Mock2Method, false), mocks.Service2Info()) time.AfterFunc(1000*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -741,7 +730,9 @@ func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) { { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) + svcSearcher := newServices() + svcSearcher.addService(svcInfo, mocks.MyServiceHandler(), &RegisterOptions{}) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearcher, ri, remote.Call, remote.Server) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) @@ -828,7 +819,9 @@ func TestInvokeHandlerExec(t *testing.T) { { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) + svcSearcher := newServices() + svcSearcher.addService(svcInfo, mocks.MyServiceHandler(), &RegisterOptions{}) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearcher, ri, remote.Call, remote.Server) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) @@ -891,7 +884,9 @@ func TestInvokeHandlerPanic(t *testing.T) { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) + svcSearcher := newServices() + svcSearcher.addService(svcInfo, mocks.MyServiceHandler(), &RegisterOptions{}) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearcher, ri, remote.Call, remote.Server) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) diff --git a/server/service.go b/server/service.go index 4d679917a2..4ce62014b2 100644 --- a/server/service.go +++ b/server/service.go @@ -17,10 +17,8 @@ package server import ( - "errors" "fmt" - "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/serviceinfo" ) @@ -34,87 +32,50 @@ func newService(svcInfo *serviceinfo.ServiceInfo, handler interface{}) *service } type services struct { - svcSearchMap map[string]*service // key: "svcName.methodName" and "methodName", value: svcInfo - svcMap map[string]*service // key: service name, value: svcInfo - conflictingMethodHasFallbackSvcMap map[string]bool - fallbackSvc *service + methodSvcMap map[string]*service // key: method name, value: svcInfo + svcMap map[string]*service // key: service name, value: svcInfo + conflictingMethodMap map[string]bool + fallbackSvc *service + + refuseTrafficWithoutServiceName bool } func newServices() *services { return &services{ - svcSearchMap: map[string]*service{}, - svcMap: map[string]*service{}, - conflictingMethodHasFallbackSvcMap: map[string]bool{}, + methodSvcMap: map[string]*service{}, + svcMap: map[string]*service{}, + conflictingMethodMap: map[string]bool{}, } } func (s *services) addService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, registerOpts *RegisterOptions) error { svc := newService(svcInfo, handler) - - if err := s.checkCombineServiceWithOtherService(svcInfo); err != nil { - return err - } - - if err := s.checkMultipleFallbackService(registerOpts, svc); err != nil { - return err - } - - s.svcMap[svcInfo.ServiceName] = svc - s.createSearchMap(svcInfo, svc, registerOpts) - return nil -} - -// when registering combine service, it does not allow the registration of other services -func (s *services) checkCombineServiceWithOtherService(svcInfo *serviceinfo.ServiceInfo) error { - if len(s.svcMap) > 0 { - if _, ok := s.svcMap["CombineService"]; ok || svcInfo.ServiceName == "CombineService" { - return errors.New("only one service can be registered when registering combine service") - } - } - return nil -} - -func (s *services) checkMultipleFallbackService(registerOpts *RegisterOptions, svc *service) error { if registerOpts.IsFallbackService { if s.fallbackSvc != nil { return fmt.Errorf("multiple fallback services cannot be registered. [%s] is already registered as a fallback service", s.fallbackSvc.svcInfo.ServiceName) } s.fallbackSvc = svc } - return nil -} - -func (s *services) createSearchMap(svcInfo *serviceinfo.ServiceInfo, svc *service, registerOpts *RegisterOptions) { + s.svcMap[svcInfo.ServiceName] = svc + // method search map for methodName := range svcInfo.Methods { - s.svcSearchMap[remote.BuildMultiServiceKey(svcInfo.ServiceName, methodName)] = svc - if svcFromMap, ok := s.svcSearchMap[methodName]; ok { - s.handleConflictingMethod(svcFromMap, svc, methodName, registerOpts) + if _, ok := s.methodSvcMap[methodName]; ok { + s.handleConflictingMethod(svc, methodName, registerOpts) } else { - s.svcSearchMap[methodName] = svc + s.methodSvcMap[methodName] = svc } } + return nil } -func (s *services) handleConflictingMethod(svcFromMap, svc *service, methodName string, registerOpts *RegisterOptions) { - s.registerConflictingMethodHasFallbackSvcMap(svcFromMap, methodName) - s.updateWithFallbackSvc(registerOpts, svc, methodName) -} - -func (s *services) registerConflictingMethodHasFallbackSvcMap(svcFromMap *service, methodName string) { - if _, ok := s.conflictingMethodHasFallbackSvcMap[methodName]; !ok { - if s.fallbackSvc != nil && svcFromMap.svcInfo.ServiceName == s.fallbackSvc.svcInfo.ServiceName { - // svc which is already registered is a fallback service - s.conflictingMethodHasFallbackSvcMap[methodName] = true - } else { - s.conflictingMethodHasFallbackSvcMap[methodName] = false - } +func (s *services) handleConflictingMethod(svc *service, methodName string, registerOpts *RegisterOptions) { + // true means has conflicting method + if _, ok := s.conflictingMethodMap[methodName]; !ok { + s.conflictingMethodMap[methodName] = true } -} - -func (s *services) updateWithFallbackSvc(registerOpts *RegisterOptions, svc *service, methodName string) { if registerOpts.IsFallbackService { - s.svcSearchMap[methodName] = svc - s.conflictingMethodHasFallbackSvcMap[methodName] = true + s.conflictingMethodMap[methodName] = false + s.methodSvcMap[methodName] = svc } } @@ -126,10 +87,27 @@ func (s *services) getSvcInfoMap() map[string]*serviceinfo.ServiceInfo { return svcInfoMap } -func (s *services) getSvcInfoSearchMap() map[string]*serviceinfo.ServiceInfo { - svcInfoSearchMap := map[string]*serviceinfo.ServiceInfo{} - for name, svc := range s.svcSearchMap { - svcInfoSearchMap[name] = svc.svcInfo +func (s *services) SearchService(svcName, methodName string, strict bool) *serviceinfo.ServiceInfo { + if strict || s.refuseTrafficWithoutServiceName { + if svc := s.svcMap[svcName]; svc != nil { + return svc.svcInfo + } + return nil } - return svcInfoSearchMap + var svc *service + if svcName == "" { + svc = s.methodSvcMap[methodName] + } else { + svc = s.svcMap[svcName] + if svc == nil { + if _, ok := s.conflictingMethodMap[methodName]; !ok { + // no conflicting method + svc = s.methodSvcMap[methodName] + } + } + } + if svc != nil { + return svc.svcInfo + } + return nil } diff --git a/server/service_test.go b/server/service_test.go index 5d800c499c..e5c196f54c 100644 --- a/server/service_test.go +++ b/server/service_test.go @@ -22,7 +22,6 @@ import ( "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/serviceinfo" ) func TestAddService(t *testing.T) { @@ -30,59 +29,31 @@ func TestAddService(t *testing.T) { err := svcs.addService(mocks.ServiceInfo(), mocks.MyServiceHandler(), &RegisterOptions{}) test.Assert(t, err == nil) test.Assert(t, len(svcs.svcMap) == 1) - fmt.Println(svcs.svcSearchMap) - test.Assert(t, len(svcs.svcSearchMap) == 10) - test.Assert(t, len(svcs.conflictingMethodHasFallbackSvcMap) == 0) + fmt.Println(svcs.methodSvcMap) + test.Assert(t, len(svcs.methodSvcMap) == 5) + test.Assert(t, len(svcs.conflictingMethodMap) == 0) test.Assert(t, svcs.fallbackSvc == nil) err = svcs.addService(mocks.Service3Info(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true}) test.Assert(t, err == nil) test.Assert(t, len(svcs.svcMap) == 2) - test.Assert(t, len(svcs.svcSearchMap) == 11) - test.Assert(t, len(svcs.conflictingMethodHasFallbackSvcMap) == 1) - test.Assert(t, svcs.conflictingMethodHasFallbackSvcMap["mock"]) + test.Assert(t, len(svcs.methodSvcMap) == 5) + test.Assert(t, svcs.SearchService("", "mock", false) == mocks.Service3Info()) + test.Assert(t, svcs.SearchService("", "mock", true) == nil) + svcs.refuseTrafficWithoutServiceName = true + test.Assert(t, svcs.SearchService("", "mock", false) == nil) + test.Assert(t, len(svcs.conflictingMethodMap) == 1) + test.Assert(t, !svcs.conflictingMethodMap["mock"]) err = svcs.addService(mocks.Service2Info(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true}) test.Assert(t, err != nil) test.Assert(t, err.Error() == "multiple fallback services cannot be registered. [MockService3] is already registered as a fallback service") } -func TestCheckCombineServiceWithOtherService(t *testing.T) { - svcs := newServices() - combineSvcInfo := &serviceinfo.ServiceInfo{ServiceName: "CombineService"} - svcs.svcMap[combineSvcInfo.ServiceName] = newService(combineSvcInfo, nil) - err := svcs.checkCombineServiceWithOtherService(mocks.ServiceInfo()) - test.Assert(t, err != nil) - test.Assert(t, err.Error() == "only one service can be registered when registering combine service") - - svcs = newServices() - svcs.svcMap[mocks.MockServiceName] = newService(mocks.ServiceInfo(), mocks.MyServiceHandler()) - err = svcs.checkCombineServiceWithOtherService(combineSvcInfo) - test.Assert(t, err != nil) - test.Assert(t, err.Error() == "only one service can be registered when registering combine service") -} - func TestCheckMultipleFallbackService(t *testing.T) { svcs := newServices() - svc := newService(mocks.ServiceInfo(), mocks.MyServiceHandler()) - registerOpts := &RegisterOptions{IsFallbackService: true} - err := svcs.checkMultipleFallbackService(registerOpts, svc) - test.Assert(t, err == nil) - test.Assert(t, svcs.fallbackSvc == svc) - - err = svcs.checkMultipleFallbackService(registerOpts, newService(mocks.Service2Info(), nil)) + _ = svcs.addService(mocks.ServiceInfo(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true}) + err := svcs.addService(mocks.ServiceInfo(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true}) test.Assert(t, err != nil) test.Assert(t, err.Error() == "multiple fallback services cannot be registered. [MockService] is already registered as a fallback service", err) } - -func TestRegisterConflictingMethodHasFallbackSvcMap(t *testing.T) { - svcs := newServices() - svcFromMap := newService(mocks.ServiceInfo(), mocks.MyServiceHandler()) - svcs.registerConflictingMethodHasFallbackSvcMap(svcFromMap, mocks.MockMethod) - test.Assert(t, !svcs.conflictingMethodHasFallbackSvcMap[mocks.MockMethod]) - - svcs = newServices() - svcs.fallbackSvc = svcFromMap - svcs.registerConflictingMethodHasFallbackSvcMap(svcFromMap, mocks.MockMethod) - test.Assert(t, svcs.conflictingMethodHasFallbackSvcMap[mocks.MockMethod]) -}