From 546548617b16c36370b1db7b986f6bdc00408c33 Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Fri, 8 Dec 2023 15:43:35 +0800 Subject: [PATCH] feat: support triple group/version division for a certain Interface (#2532) * feat: support triple group/version division for a certain Interface * add unit test * import format * fix some comments --- config/provider_config.go | 4 +- protocol/triple/client.go | 18 +- protocol/triple/health/health_test.go | 4 +- .../dubbo3_server/api/greet_service.go | 56 ++++ .../internal/server/api/greet_service.go | 54 +++- protocol/triple/server.go | 103 +++---- protocol/triple/triple.go | 2 +- protocol/triple/triple_protocol/client.go | 13 + protocol/triple/triple_protocol/handler.go | 252 ++++++++++++------ .../triple/triple_protocol/handler_compat.go | 37 ++- .../triple_protocol/handler_stream_compat.go | 48 +++- protocol/triple/triple_protocol/option.go | 32 +++ protocol/triple/triple_protocol/server.go | 179 +++++++++++++ protocol/triple/triple_protocol/triple.go | 5 + protocol/triple/triple_test.go | 112 ++++++-- 15 files changed, 711 insertions(+), 208 deletions(-) create mode 100644 protocol/triple/triple_protocol/server.go diff --git a/config/provider_config.go b/config/provider_config.go index 3a99e9f700..5ca41f33bc 100644 --- a/config/provider_config.go +++ b/config/provider_config.go @@ -27,9 +27,9 @@ import ( "github.com/dubbogo/gost/log/logger" - perrors "github.com/pkg/errors" - tripleConstant "github.com/dubbogo/triple/pkg/common/constant" + + perrors "github.com/pkg/errors" ) import ( diff --git a/protocol/triple/client.go b/protocol/triple/client.go index 87799d291f..18b339e288 100644 --- a/protocol/triple/client.go +++ b/protocol/triple/client.go @@ -119,33 +119,37 @@ func newClientManager(url *common.URL) (*clientManager, error) { // If global trace instance was set, it means trace function enabled. // If not, will return NoopTracer. // tracer := opentracing.GlobalTracer() - var triClientOpts []tri.ClientOption + var cliOpts []tri.ClientOption // set max send and recv msg size maxCallRecvMsgSize := constant.DefaultMaxCallRecvMsgSize if recvMsgSize, err := humanize.ParseBytes(url.GetParam(constant.MaxCallRecvMsgSize, "")); err == nil && recvMsgSize > 0 { maxCallRecvMsgSize = int(recvMsgSize) } - triClientOpts = append(triClientOpts, tri.WithReadMaxBytes(maxCallRecvMsgSize)) + cliOpts = append(cliOpts, tri.WithReadMaxBytes(maxCallRecvMsgSize)) maxCallSendMsgSize := constant.DefaultMaxCallSendMsgSize if sendMsgSize, err := humanize.ParseBytes(url.GetParam(constant.MaxCallSendMsgSize, "")); err == nil && sendMsgSize > 0 { maxCallSendMsgSize = int(sendMsgSize) } - triClientOpts = append(triClientOpts, tri.WithSendMaxBytes(maxCallSendMsgSize)) + cliOpts = append(cliOpts, tri.WithSendMaxBytes(maxCallSendMsgSize)) // set serialization serialization := url.GetParam(constant.SerializationKey, constant.ProtobufSerialization) switch serialization { case constant.ProtobufSerialization: case constant.JSONSerialization: - triClientOpts = append(triClientOpts, tri.WithProtoJSON()) + cliOpts = append(cliOpts, tri.WithProtoJSON()) default: panic(fmt.Sprintf("Unsupported serialization: %s", serialization)) } // set timeout timeout := url.GetParamDuration(constant.TimeoutKey, "") - triClientOpts = append(triClientOpts, tri.WithTimeout(timeout)) + cliOpts = append(cliOpts, tri.WithTimeout(timeout)) + + group := url.GetParam(constant.GroupKey, "") + version := url.GetParam(constant.VersionKey, "") + cliOpts = append(cliOpts, tri.WithGroup(group), tri.WithVersion(version)) // dialOpts = append(dialOpts, // @@ -187,7 +191,7 @@ func newClientManager(url *common.URL) (*clientManager, error) { transport = &http.Transport{ TLSClientConfig: cfg, } - triClientOpts = append(triClientOpts, tri.WithTriple()) + cliOpts = append(cliOpts, tri.WithTriple()) case constant.CallHTTP2: if tlsFlag { transport = &http2.Transport{ @@ -222,7 +226,7 @@ func newClientManager(url *common.URL) (*clientManager, error) { if err != nil { return nil, fmt.Errorf("JoinPath failed for base %s, interface %s, method %s", baseTriURL, url.Interface(), method) } - triClient := tri.NewClient(httpClient, triURL, triClientOpts...) + triClient := tri.NewClient(httpClient, triURL, cliOpts...) triClients[method] = triClient } diff --git a/protocol/triple/health/health_test.go b/protocol/triple/health/health_test.go index d9ee7a4d0c..aeda9546c3 100644 --- a/protocol/triple/health/health_test.go +++ b/protocol/triple/health/health_test.go @@ -22,10 +22,10 @@ import ( "testing" "time" ) + import ( "github.com/stretchr/testify/assert" - // If there is a conflict between the healthCheck of Dubbo and the healthCheck of gRPC, an error will occur. _ "google.golang.org/grpc/health/grpc_health_v1" ) @@ -35,6 +35,8 @@ import ( const testService = "testService" +// If there is a conflict between the healthCheck of Dubbo and the healthCheck of gRPC, an error will occur. + func TestSetServingStatus(t *testing.T) { s := NewServer() s.SetServingStatus(testService, healthpb.HealthCheckResponse_SERVING) diff --git a/protocol/triple/internal/dubbo3_server/api/greet_service.go b/protocol/triple/internal/dubbo3_server/api/greet_service.go index 81043f9921..d3b080e00d 100644 --- a/protocol/triple/internal/dubbo3_server/api/greet_service.go +++ b/protocol/triple/internal/dubbo3_server/api/greet_service.go @@ -84,3 +84,59 @@ func (srv *GreetDubbo3Server) GreetServerStream(req *proto.GreetServerStreamRequ } return nil } + +const ( + GroupVersionIdentifier = "g1v1" +) + +type GreetDubbo3ServerGroup1Version1 struct { + greet.UnimplementedGreetServiceServer +} + +func (srv *GreetDubbo3ServerGroup1Version1) Greet(ctx context.Context, req *proto.GreetRequest) (*proto.GreetResponse, error) { + return &proto.GreetResponse{Greeting: GroupVersionIdentifier + req.Name}, nil +} + +func (srv *GreetDubbo3ServerGroup1Version1) GreetStream(stream greet.GreetService_GreetStreamServer) error { + for { + req, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("dubbo3 Bidistream recv error: %s", err) + } + if err := stream.Send(&proto.GreetStreamResponse{Greeting: GroupVersionIdentifier + req.Name}); err != nil { + return fmt.Errorf("dubbo3 Bidistream send error: %s", err) + } + } + return nil +} + +func (srv *GreetDubbo3ServerGroup1Version1) GreetClientStream(stream greet.GreetService_GreetClientStreamServer) error { + var reqs []string + for { + req, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("dubbo3 ClientStream recv error: %s", err) + } + reqs = append(reqs, GroupVersionIdentifier+req.Name) + } + + resp := &proto.GreetClientStreamResponse{ + Greeting: strings.Join(reqs, ","), + } + return stream.SendAndClose(resp) +} + +func (srv *GreetDubbo3ServerGroup1Version1) GreetServerStream(req *proto.GreetServerStreamRequest, stream greet.GreetService_GreetServerStreamServer) error { + for i := 0; i < 5; i++ { + if err := stream.Send(&proto.GreetServerStreamResponse{Greeting: GroupVersionIdentifier + req.Name}); err != nil { + return fmt.Errorf("dubbo3 ServerStream send error: %s", err) + } + } + return nil +} diff --git a/protocol/triple/internal/server/api/greet_service.go b/protocol/triple/internal/server/api/greet_service.go index 99d6818654..7e0aa504ca 100644 --- a/protocol/triple/internal/server/api/greet_service.go +++ b/protocol/triple/internal/server/api/greet_service.go @@ -29,8 +29,7 @@ import ( triple "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" ) -type GreetTripleServer struct { -} +type GreetTripleServer struct{} func (srv *GreetTripleServer) Greet(ctx context.Context, req *greet.GreetRequest) (*greet.GreetResponse, error) { resp := &greet.GreetResponse{Greeting: req.Name} @@ -76,3 +75,54 @@ func (srv *GreetTripleServer) GreetServerStream(ctx context.Context, req *greet. } return nil } + +const ( + GroupVersionIdentifier = "g1v1" +) + +type GreetTripleServerGroup1Version1 struct{} + +func (g *GreetTripleServerGroup1Version1) Greet(ctx context.Context, req *greet.GreetRequest) (*greet.GreetResponse, error) { + resp := &greet.GreetResponse{Greeting: GroupVersionIdentifier + req.Name} + return resp, nil +} + +func (g *GreetTripleServerGroup1Version1) GreetStream(ctx context.Context, stream greettriple.GreetService_GreetStreamServer) error { + for { + req, err := stream.Recv() + if err != nil { + if triple.IsEnded(err) { + break + } + return fmt.Errorf("triple BidiStream recv error: %s", err) + } + if err := stream.Send(&greet.GreetStreamResponse{Greeting: GroupVersionIdentifier + req.Name}); err != nil { + return fmt.Errorf("triple BidiStream send error: %s", err) + } + } + return nil +} + +func (g *GreetTripleServerGroup1Version1) GreetClientStream(ctx context.Context, stream greettriple.GreetService_GreetClientStreamServer) (*greet.GreetClientStreamResponse, error) { + var reqs []string + for stream.Recv() { + reqs = append(reqs, GroupVersionIdentifier+stream.Msg().Name) + } + if stream.Err() != nil && !triple.IsEnded(stream.Err()) { + return nil, fmt.Errorf("triple ClientStream recv err: %s", stream.Err()) + } + resp := &greet.GreetClientStreamResponse{ + Greeting: strings.Join(reqs, ","), + } + + return resp, nil +} + +func (g *GreetTripleServerGroup1Version1) GreetServerStream(ctx context.Context, req *greet.GreetServerStreamRequest, stream greettriple.GreetService_GreetServerStreamServer) error { + for i := 0; i < 5; i++ { + if err := stream.Send(&greet.GreetServerStreamResponse{Greeting: GroupVersionIdentifier + req.Name}); err != nil { + return fmt.Errorf("triple ServerStream send err: %s", err) + } + } + return nil +} diff --git a/protocol/triple/server.go b/protocol/triple/server.go index 847b050a0e..fd6c6b9758 100644 --- a/protocol/triple/server.go +++ b/protocol/triple/server.go @@ -19,9 +19,7 @@ package triple import ( "context" - "crypto/tls" "fmt" - "net/http" "sync" "time" ) @@ -31,9 +29,6 @@ import ( "github.com/dustin/go-humanize" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" - "google.golang.org/grpc" ) @@ -51,16 +46,15 @@ import ( // Server is TRIPLE server type Server struct { - httpServer *http.Server - handler *http.ServeMux - services map[string]grpc.ServiceInfo - mu sync.RWMutex + triServer *tri.Server + services map[string]grpc.ServiceInfo + mu sync.RWMutex } // NewServer creates a new TRIPLE server func NewServer() *Server { return &Server{ - handler: http.NewServeMux(), + services: make(map[string]grpc.ServiceInfo), } } @@ -68,15 +62,12 @@ func NewServer() *Server { func (s *Server) Start(invoker protocol.Invoker, info *server.ServiceInfo) { var ( addr string - err error URL *common.URL hanOpts []tri.HandlerOption ) URL = invoker.GetURL() addr = URL.Location - srv := &http.Server{ - Addr: addr, - } + s.triServer = tri.NewServer(addr) serialization := URL.GetParam(constant.SerializationKey, constant.ProtobufSerialization) switch serialization { case constant.ProtobufSerialization: @@ -94,7 +85,7 @@ func (s *Server) Start(invoker protocol.Invoker, info *server.ServiceInfo) { // grpc.MaxRecvMsgSize(maxServerRecvMsgSize), // grpc.MaxSendMsgSize(maxServerSendMsgSize), //) - var cfg *tls.Config + //var cfg *tls.Config // todo(DMwangnima): think about a more elegant way to configure tls //tlsConfig := config.GetRootConfig().TLSConfig //if tlsConfig != nil { @@ -113,27 +104,17 @@ func (s *Server) Start(invoker protocol.Invoker, info *server.ServiceInfo) { // todo:// move tls config to handleService hanOpts = getHanOpts(URL) - s.httpServer = srv + if info != nil { + s.handleServiceWithInfo(invoker, info, hanOpts...) + s.saveServiceInfo(info) + } else { + s.compatHandleService(URL, hanOpts...) + } + reflection.Register(s) go func() { - mux := s.handler - if info != nil { - handleServiceWithInfo(invoker, info, mux, hanOpts...) - s.saveServiceInfo(info) - } else { - compatHandleService(URL, mux) - } - // todo: figure it out this process - reflection.Register(s) - // todo: without tls - if cfg == nil { - srv.Handler = h2c.NewHandler(mux, &http2.Server{}) - } else { - srv.Handler = mux - } - - if err = srv.ListenAndServe(); err != nil { - logger.Errorf("server serve failed with err: %v", err) + if runErr := s.triServer.Run(); runErr != nil { + logger.Errorf("server serve failed with err: %v", runErr) } }() } @@ -153,12 +134,11 @@ func (s *Server) RefreshService(invoker protocol.Invoker, info *server.ServiceIn panic(fmt.Sprintf("Unsupported serialization: %s", serialization)) } hanOpts = getHanOpts(URL) - mux := s.handler if info != nil { - handleServiceWithInfo(invoker, info, mux, hanOpts...) + s.handleServiceWithInfo(invoker, info, hanOpts...) s.saveServiceInfo(info) } else { - compatHandleService(URL, mux) + s.compatHandleService(URL, hanOpts...) } } @@ -178,6 +158,10 @@ func getHanOpts(url *common.URL) (hanOpts []tri.HandlerOption) { // todo:// open tracing hanOpts = append(hanOpts, tri.WithInterceptors()) + + group := url.GetParam(constant.GroupKey, "") + version := url.GetParam(constant.VersionKey, "") + hanOpts = append(hanOpts, tri.WithGroup(group), tri.WithVersion(version)) return hanOpts } @@ -214,11 +198,11 @@ func waitTripleExporter(providerServices map[string]*config.ServiceConfig) { } // *Important*, this function is responsible for being compatible with old triple-gen code -// compatHandleService creates handler based on ServiceConfig and provider service. -func compatHandleService(url *common.URL, mux *http.ServeMux, opts ...tri.HandlerOption) { +// compatHandleService registers handler based on ServiceConfig and provider service. +func (s *Server) compatHandleService(url *common.URL, opts ...tri.HandlerOption) { providerServices := config.GetProviderConfig().Services if len(providerServices) == 0 { - panic("Provider service map is null") + logger.Info("Provider service map is null") } //waitTripleExporter(providerServices) for key, providerService := range providerServices { @@ -246,22 +230,18 @@ func compatHandleService(url *common.URL, mux *http.ServeMux, opts ...tri.Handle // inject invoker, it has all invocation logics ds.XXX_SetProxyImpl(invoker) - path, handler := compatBuildHandler(ds, opts...) - mux.Handle(path, handler) + s.compatRegisterHandler(ds, opts...) } } -func compatBuildHandler(svc dubbo3.Dubbo3GrpcService, opts ...tri.HandlerOption) (string, http.Handler) { - mux := http.NewServeMux() +func (s *Server) compatRegisterHandler(svc dubbo3.Dubbo3GrpcService, opts ...tri.HandlerOption) { desc := svc.XXX_ServiceDesc() - basePath := desc.ServiceName // init unary handlers for _, method := range desc.Methods { // please refer to protocol/triple/internal/proto/triple_gen/greettriple for procedure examples // error could be ignored because base is empty string procedure := joinProcedure(desc.ServiceName, method.MethodName) - handler := tri.NewCompatUnaryHandler(procedure, svc, tri.MethodHandler(method.Handler), opts...) - mux.Handle(procedure, handler) + _ = s.triServer.RegisterCompatUnaryHandler(procedure, svc, tri.MethodHandler(method.Handler), opts...) } // init stream handlers @@ -278,22 +258,18 @@ func compatBuildHandler(svc dubbo3.Dubbo3GrpcService, opts ...tri.HandlerOption) case stream.ServerStreams: typ = tri.StreamTypeServer } - handler := tri.NewCompatStreamHandler(procedure, svc, typ, stream.Handler, opts...) - mux.Handle(procedure, handler) + _ = s.triServer.RegisterCompatStreamHandler(procedure, svc, typ, stream.Handler, opts...) } - - return "/" + basePath + "/", mux } // handleServiceWithInfo injects invoker and create handler based on ServiceInfo -func handleServiceWithInfo(invoker protocol.Invoker, info *server.ServiceInfo, mux *http.ServeMux, opts ...tri.HandlerOption) { +func (s *Server) handleServiceWithInfo(invoker protocol.Invoker, info *server.ServiceInfo, opts ...tri.HandlerOption) { for _, method := range info.Methods { m := method - var handler http.Handler procedure := joinProcedure(info.InterfaceName, method.Name) switch m.Type { case constant.CallUnary: - handler = tri.NewUnaryHandler( + _ = s.triServer.RegisterUnaryHandler( procedure, m.ReqInitFunc, func(ctx context.Context, req *tri.Request) (*tri.Response, error) { @@ -308,7 +284,7 @@ func handleServiceWithInfo(invoker protocol.Invoker, info *server.ServiceInfo, m opts..., ) case constant.CallClientStream: - handler = tri.NewClientStreamHandler( + _ = s.triServer.RegisterClientStreamHandler( procedure, func(ctx context.Context, stream *tri.ClientStream) (*tri.Response, error) { var args []interface{} @@ -317,9 +293,10 @@ func handleServiceWithInfo(invoker protocol.Invoker, info *server.ServiceInfo, m res := invoker.Invoke(ctx, invo) return res.Result().(*tri.Response), res.Error() }, + opts..., ) case constant.CallServerStream: - handler = tri.NewServerStreamHandler( + _ = s.triServer.RegisterServerStreamHandler( procedure, m.ReqInitFunc, func(ctx context.Context, request *tri.Request, stream *tri.ServerStream) error { @@ -329,9 +306,10 @@ func handleServiceWithInfo(invoker protocol.Invoker, info *server.ServiceInfo, m res := invoker.Invoke(ctx, invo) return res.Error() }, + opts..., ) case constant.CallBidiStream: - handler = tri.NewBidiStreamHandler( + _ = s.triServer.RegisterBidiStreamHandler( procedure, func(ctx context.Context, stream *tri.BidiStream) error { var args []interface{} @@ -340,9 +318,9 @@ func handleServiceWithInfo(invoker protocol.Invoker, info *server.ServiceInfo, m res := invoker.Invoke(ctx, invo) return res.Error() }, + opts..., ) } - mux.Handle(procedure, handler) } } @@ -371,9 +349,6 @@ func (s *Server) saveServiceInfo(info *server.ServiceInfo) { ret.Metadata = info s.mu.Lock() defer s.mu.Unlock() - if s.services == nil { - s.services = make(map[string]grpc.ServiceInfo) - } s.services[info.InterfaceName] = ret } @@ -389,12 +364,10 @@ func (s *Server) GetServiceInfo() map[string]grpc.ServiceInfo { // Stop TRIPLE server func (s *Server) Stop() { - // todo: process error - s.httpServer.Close() + _ = s.triServer.Stop() } // GracefulStop TRIPLE server func (s *Server) GracefulStop() { - // todo: process error and use timeout - s.httpServer.Shutdown(context.Background()) + _ = s.triServer.GracefulStop(context.Background()) } diff --git a/protocol/triple/triple.go b/protocol/triple/triple.go index de083aafcc..6f41520c5a 100644 --- a/protocol/triple/triple.go +++ b/protocol/triple/triple.go @@ -64,9 +64,9 @@ func (tp *TripleProtocol) Export(invoker protocol.Invoker) protocol.Exporter { } exporter := NewTripleExporter(serviceKey, invoker, tp.ExporterMap()) tp.SetExporterMap(serviceKey, exporter) - health.SetServingStatusServing(url.Service()) logger.Infof("[TRIPLE Protocol] Export service: %s", url.String()) tp.openServer(invoker, info) + health.SetServingStatusServing(url.Service()) return exporter } diff --git a/protocol/triple/triple_protocol/client.go b/protocol/triple/triple_protocol/client.go index 77ebd5bd6e..4d49a6ff27 100644 --- a/protocol/triple/triple_protocol/client.go +++ b/protocol/triple/triple_protocol/client.go @@ -124,6 +124,7 @@ func (c *Client) CallUnary(ctx context.Context, request *Request, response *Resp if flag { defer cancel() } + applyGroupVersionHeaders(request.Header(), c.config) return c.callUnary(ctx, request, response) } @@ -169,6 +170,7 @@ func (c *Client) CallBidiStream(ctx context.Context) (*BidiStreamForClient, erro func (c *Client) newConn(ctx context.Context, streamType StreamType) StreamingClientConn { newConn := func(ctx context.Context, spec Spec) StreamingClientConn { header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing + applyGroupVersionHeaders(header, c.config) c.protocolClient.WriteRequestHeader(streamType, header) return c.protocolClient.NewConn(ctx, spec, header) } @@ -195,6 +197,8 @@ type clientConfig struct { GetUseFallback bool IdempotencyLevel IdempotencyLevel Timeout time.Duration + Group string + Version string } func newClientConfig(rawURL string, options []ClientOption) (*clientConfig, *Error) { @@ -279,3 +283,12 @@ func applyDefaultTimeout(ctx context.Context, timeout time.Duration) (context.Co return ctx, applyFlag, cancel } + +func applyGroupVersionHeaders(header http.Header, cfg *clientConfig) { + if cfg.Group != "" { + header.Set(tripleServiceGroup, cfg.Group) + } + if cfg.Version != "" { + header.Set(tripleServiceVersion, cfg.Version) + } +} diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go index 9bca672b57..3c02ee62af 100644 --- a/protocol/triple/triple_protocol/handler.go +++ b/protocol/triple/triple_protocol/handler.go @@ -20,6 +20,10 @@ import ( "net/http" ) +const ( + defaultImplementationsSize = 4 +) + // A Handler is the server-side implementation of a single RPC defined by a // service schema. // @@ -27,8 +31,9 @@ import ( // the binary Protobuf and JSON codecs. They support gzip compression using the // standard library's [compress/gzip]. type Handler struct { - spec Spec - implementation StreamingHandlerFunc + spec Spec + // key is group/version + implementations map[string]StreamingHandlerFunc protocolHandlers []protocolHandler allowMethod string // Allow header acceptPost string // Accept-Post header @@ -41,6 +46,27 @@ func NewUnaryHandler( unary func(context.Context, *Request) (*Response, error), options ...HandlerOption, ) *Handler { + config := newHandlerConfig(procedure, options) + implementation := generateUnaryHandlerFunc(procedure, reqInitFunc, unary, config.Interceptor) + protocolHandlers := config.newProtocolHandlers(StreamTypeUnary) + + hdl := &Handler{ + spec: config.newSpec(StreamTypeUnary), + implementations: make(map[string]StreamingHandlerFunc, defaultImplementationsSize), + protocolHandlers: protocolHandlers, + allowMethod: sortedAllowMethodValue(protocolHandlers), + acceptPost: sortedAcceptPostValue(protocolHandlers), + } + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + return hdl +} + +func generateUnaryHandlerFunc( + procedure string, + reqInitFunc func() interface{}, + unary func(context.Context, *Request) (*Response, error), + interceptor Interceptor, +) StreamingHandlerFunc { // Wrap the strongly-typed implementation so we can apply interceptors. untyped := UnaryHandlerFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { // verify err @@ -60,8 +86,7 @@ func NewUnaryHandler( return res, err }) // todo: modify server func - config := newHandlerConfig(procedure, options) - if interceptor := config.Interceptor; interceptor != nil { + if interceptor != nil { untyped = interceptor.WrapUnaryHandler(untyped) } // receive and send @@ -90,97 +115,158 @@ func NewUnaryHandler( return conn.Send(response.Any()) } - protocolHandlers := config.newProtocolHandlers(StreamTypeUnary) - return &Handler{ - spec: config.newSpec(StreamTypeUnary), - implementation: implementation, - protocolHandlers: protocolHandlers, - allowMethod: sortedAllowMethodValue(protocolHandlers), - acceptPost: sortedAcceptPostValue(protocolHandlers), - } + return implementation } // NewClientStreamHandler constructs a [Handler] for a client streaming procedure. func NewClientStreamHandler( procedure string, - implementation func(context.Context, *ClientStream) (*Response, error), + streamFunc func(context.Context, *ClientStream) (*Response, error), options ...HandlerOption, ) *Handler { - return newStreamHandler( - procedure, - StreamTypeClient, - func(ctx context.Context, conn StreamingHandlerConn) error { - stream := &ClientStream{conn: conn} - // embed header in context so that user logic could process them via FromIncomingContext - ctx = newIncomingContext(ctx, conn.RequestHeader()) - res, err := implementation(ctx, stream) - if err != nil { - return err - } - if res == nil { - // This is going to panic during serialization. Debugging is much easier - // if we panic here instead, so we can include the procedure name. - panic(fmt.Sprintf("%s returned nil *triple.Response and nil error", procedure)) //nolint: forbidigo - } - mergeHeaders(conn.ResponseHeader(), res.header) - mergeHeaders(conn.ResponseTrailer(), res.trailer) - return conn.Send(res.Msg) - }, - options..., - ) + config := newHandlerConfig(procedure, options) + implementation := generateClientStreamHandlerFunc(procedure, streamFunc, config.Interceptor) + protocolHandlers := config.newProtocolHandlers(StreamTypeClient) + + hdl := &Handler{ + spec: config.newSpec(StreamTypeClient), + implementations: make(map[string]StreamingHandlerFunc, defaultImplementationsSize), + protocolHandlers: protocolHandlers, + allowMethod: sortedAllowMethodValue(protocolHandlers), + acceptPost: sortedAcceptPostValue(protocolHandlers), + } + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + + return hdl +} + +func generateClientStreamHandlerFunc( + procedure string, + streamFunc func(context.Context, *ClientStream) (*Response, error), + interceptor Interceptor, +) StreamingHandlerFunc { + implementation := func(ctx context.Context, conn StreamingHandlerConn) error { + stream := &ClientStream{conn: conn} + // embed header in context so that user logic could process them via FromIncomingContext + ctx = newIncomingContext(ctx, conn.RequestHeader()) + res, err := streamFunc(ctx, stream) + if err != nil { + return err + } + if res == nil { + // This is going to panic during serialization. Debugging is much easier + // if we panic here instead, so we can include the procedure name. + panic(fmt.Sprintf("%s returned nil *triple.Response and nil error", procedure)) //nolint: forbidigo + } + mergeHeaders(conn.ResponseHeader(), res.header) + mergeHeaders(conn.ResponseTrailer(), res.trailer) + return conn.Send(res.Msg) + } + if interceptor != nil { + implementation = interceptor.WrapStreamingHandler(implementation) + } + + return implementation } // NewServerStreamHandler constructs a [Handler] for a server streaming procedure. func NewServerStreamHandler( procedure string, reqInitFunc func() interface{}, - implementation func(context.Context, *Request, *ServerStream) error, + streamFunc func(context.Context, *Request, *ServerStream) error, options ...HandlerOption, ) *Handler { - return newStreamHandler( - procedure, - StreamTypeServer, - func(ctx context.Context, conn StreamingHandlerConn) error { - req := reqInitFunc() - if err := conn.Receive(req); err != nil { - return err - } - // embed header in context so that user logic could process them via FromIncomingContext - ctx = newIncomingContext(ctx, conn.RequestHeader()) - return implementation( - ctx, - &Request{ - Msg: req, - spec: conn.Spec(), - peer: conn.Peer(), - header: conn.RequestHeader(), - }, - &ServerStream{conn: conn}, - ) - }, - options..., - ) + config := newHandlerConfig(procedure, options) + implementation := generateServerStreamHandlerFunc(procedure, reqInitFunc, streamFunc, config.Interceptor) + protocolHandlers := config.newProtocolHandlers(StreamTypeServer) + + hdl := &Handler{ + spec: config.newSpec(StreamTypeServer), + implementations: make(map[string]StreamingHandlerFunc, defaultImplementationsSize), + protocolHandlers: protocolHandlers, + allowMethod: sortedAllowMethodValue(protocolHandlers), + acceptPost: sortedAcceptPostValue(protocolHandlers), + } + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + + return hdl +} + +func generateServerStreamHandlerFunc( + procedure string, + reqInitFunc func() interface{}, + streamFunc func(context.Context, *Request, *ServerStream) error, + interceptor Interceptor, +) StreamingHandlerFunc { + implementation := func(ctx context.Context, conn StreamingHandlerConn) error { + req := reqInitFunc() + if err := conn.Receive(req); err != nil { + return err + } + // embed header in context so that user logic could process them via FromIncomingContext + ctx = newIncomingContext(ctx, conn.RequestHeader()) + return streamFunc( + ctx, + &Request{ + Msg: req, + spec: conn.Spec(), + peer: conn.Peer(), + header: conn.RequestHeader(), + }, + &ServerStream{conn: conn}, + ) + } + if interceptor != nil { + implementation = interceptor.WrapStreamingHandler(implementation) + } + + return implementation } // NewBidiStreamHandler constructs a [Handler] for a bidirectional streaming procedure. func NewBidiStreamHandler( procedure string, - implementation func(context.Context, *BidiStream) error, + streamFunc func(context.Context, *BidiStream) error, options ...HandlerOption, ) *Handler { - return newStreamHandler( - procedure, - StreamTypeBidi, - func(ctx context.Context, conn StreamingHandlerConn) error { - // embed header in context so that user logic could process them via FromIncomingContext - ctx = newIncomingContext(ctx, conn.RequestHeader()) - return implementation( - ctx, - &BidiStream{conn: conn}, - ) - }, - options..., - ) + config := newHandlerConfig(procedure, options) + implementation := generateBidiStreamHandlerFunc(procedure, streamFunc, config.Interceptor) + protocolHandlers := config.newProtocolHandlers(StreamTypeBidi) + + hdl := &Handler{ + spec: config.newSpec(StreamTypeBidi), + implementations: make(map[string]StreamingHandlerFunc, defaultImplementationsSize), + protocolHandlers: protocolHandlers, + allowMethod: sortedAllowMethodValue(protocolHandlers), + acceptPost: sortedAcceptPostValue(protocolHandlers), + } + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + + return hdl +} + +func generateBidiStreamHandlerFunc( + procedure string, + streamFunc func(context.Context, *BidiStream) error, + interceptor Interceptor, +) StreamingHandlerFunc { + implementation := func(ctx context.Context, conn StreamingHandlerConn) error { + // embed header in context so that user logic could process them via FromIncomingContext + ctx = newIncomingContext(ctx, conn.RequestHeader()) + return streamFunc( + ctx, + &BidiStream{conn: conn}, + ) + } + if interceptor != nil { + implementation = interceptor.WrapStreamingHandler(implementation) + } + + return implementation +} + +func (h *Handler) processImplementation(identifier string, implementation StreamingHandlerFunc) { + h.implementations[identifier] = implementation } // ServeHTTP implements [http.Handler]. @@ -255,7 +341,11 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re return } // invoke implementation - _ = connCloser.Close(h.implementation(ctx, connCloser)) + svcGroup := request.Header.Get(tripleServiceGroup) + svcVersion := request.Header.Get(tripleServiceVersion) + implementation := h.implementations[getIdentifier(svcGroup, svcVersion)] + // todo(DMwangnima): inspect ok + _ = connCloser.Close(implementation(ctx, connCloser)) } type handlerConfig struct { @@ -271,6 +361,8 @@ type handlerConfig struct { BufferPool *bufferPool ReadMaxBytes int SendMaxBytes int + Group string + Version string } func newHandlerConfig(procedure string, options []HandlerOption) *handlerConfig { @@ -333,6 +425,10 @@ func (c *handlerConfig) newProtocolHandlers(streamType StreamType) []protocolHan return handlers } +func getIdentifier(group, version string) string { + return group + "/" + version +} + func newStreamHandler( procedure string, streamType StreamType, @@ -344,11 +440,15 @@ func newStreamHandler( implementation = ic.WrapStreamingHandler(implementation) } protocolHandlers := config.newProtocolHandlers(streamType) - return &Handler{ + + hdl := &Handler{ spec: config.newSpec(streamType), - implementation: implementation, + implementations: make(map[string]StreamingHandlerFunc, defaultImplementationsSize), protocolHandlers: protocolHandlers, allowMethod: sortedAllowMethodValue(protocolHandlers), acceptPost: sortedAcceptPostValue(protocolHandlers), } + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + + return hdl } diff --git a/protocol/triple/triple_protocol/handler_compat.go b/protocol/triple/triple_protocol/handler_compat.go index 5730abce6a..7c21993f9d 100644 --- a/protocol/triple/triple_protocol/handler_compat.go +++ b/protocol/triple/triple_protocol/handler_compat.go @@ -84,14 +84,34 @@ func NewCompatUnaryHandler( options ...HandlerOption, ) *Handler { config := newHandlerConfig(procedure, options) + implementation := generateCompatUnaryHandlerFunc(procedure, srv, unary, config.Interceptor) + protocolHandlers := config.newProtocolHandlers(StreamTypeUnary) + + hdl := &Handler{ + spec: config.newSpec(StreamTypeUnary), + implementations: make(map[string]StreamingHandlerFunc, defaultImplementationsSize), + protocolHandlers: protocolHandlers, + allowMethod: sortedAllowMethodValue(protocolHandlers), + acceptPost: sortedAcceptPostValue(protocolHandlers), + } + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) - implementation := func(ctx context.Context, conn StreamingHandlerConn) error { + return hdl +} + +func generateCompatUnaryHandlerFunc( + procedure string, + srv interface{}, + unary MethodHandler, + interceptor Interceptor, +) StreamingHandlerFunc { + return func(ctx context.Context, conn StreamingHandlerConn) error { compatInterceptor := &tripleCompatInterceptor{ spec: conn.Spec(), peer: conn.Peer(), header: conn.RequestHeader(), - procedure: config.Procedure, - interceptor: config.Interceptor, + procedure: procedure, + interceptor: interceptor, } decodeFunc := func(req interface{}) error { if err := conn.Receive(req); err != nil { @@ -101,7 +121,7 @@ func NewCompatUnaryHandler( } // staticcheck error: SA1029. Stub code generated by protoc-gen-go-triple makes use of "XXX_TRIPLE_GO_INTERFACE_NAME" directly //nolint:staticcheck - ctx = context.WithValue(ctx, "XXX_TRIPLE_GO_INTERFACE_NAME", config.Procedure) + ctx = context.WithValue(ctx, "XXX_TRIPLE_GO_INTERFACE_NAME", procedure) respRaw, err := unary(srv, ctx, decodeFunc, compatInterceptor.compatUnaryServerInterceptor) if err != nil { return err @@ -112,13 +132,4 @@ func NewCompatUnaryHandler( mergeHeaders(conn.ResponseTrailer(), resp.Trailer()) return conn.Send(resp.Any()) } - - protocolHandlers := config.newProtocolHandlers(StreamTypeUnary) - return &Handler{ - spec: config.newSpec(StreamTypeUnary), - implementation: implementation, - protocolHandlers: protocolHandlers, - allowMethod: sortedAllowMethodValue(protocolHandlers), - acceptPost: sortedAcceptPostValue(protocolHandlers), - } } diff --git a/protocol/triple/triple_protocol/handler_stream_compat.go b/protocol/triple/triple_protocol/handler_stream_compat.go index 9cac7b572a..bf969928be 100644 --- a/protocol/triple/triple_protocol/handler_stream_compat.go +++ b/protocol/triple/triple_protocol/handler_stream_compat.go @@ -62,19 +62,41 @@ func NewCompatStreamHandler( procedure string, srv interface{}, typ StreamType, - implementation func(srv interface{}, stream grpc.ServerStream) error, + streamFunc func(srv interface{}, stream grpc.ServerStream) error, options ...HandlerOption, ) *Handler { - return newStreamHandler( - procedure, - typ, - func(ctx context.Context, conn StreamingHandlerConn) error { - stream := &compatHandlerStream{ - ctx: ctx, - conn: conn, - } - return implementation(srv, stream) - }, - options..., - ) + config := newHandlerConfig(procedure, options) + implementation := generateCompatStreamHandlerFunc(procedure, srv, streamFunc, config.Interceptor) + protocolHandlers := config.newProtocolHandlers(typ) + + hdl := &Handler{ + spec: config.newSpec(typ), + implementations: make(map[string]StreamingHandlerFunc, defaultImplementationsSize), + protocolHandlers: protocolHandlers, + allowMethod: sortedAllowMethodValue(protocolHandlers), + acceptPost: sortedAcceptPostValue(protocolHandlers), + } + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + + return hdl +} + +func generateCompatStreamHandlerFunc( + procedure string, + srv interface{}, + streamFunc func(interface{}, grpc.ServerStream) error, + interceptor Interceptor, +) StreamingHandlerFunc { + implementation := func(ctx context.Context, conn StreamingHandlerConn) error { + stream := &compatHandlerStream{ + ctx: ctx, + conn: conn, + } + return streamFunc(srv, stream) + } + if interceptor != nil { + implementation = interceptor.WrapStreamingHandler(implementation) + } + + return implementation } diff --git a/protocol/triple/triple_protocol/option.go b/protocol/triple/triple_protocol/option.go index dab116074a..5cef7d1ae3 100644 --- a/protocol/triple/triple_protocol/option.go +++ b/protocol/triple/triple_protocol/option.go @@ -169,6 +169,14 @@ func WithRequireTripleProtocolHeader() HandlerOption { return &requireTripleProtocolHeaderOption{} } +func WithGroup(group string) Option { + return &groupOption{group} +} + +func WithVersion(version string) Option { + return &versionOption{version} +} + // Option implements both [ClientOption] and [HandlerOption], so it can be // applied both client-side and server-side. type Option interface { @@ -418,6 +426,30 @@ func (o *requireTripleProtocolHeaderOption) applyToHandler(config *handlerConfig config.RequireTripleProtocolHeader = true } +type groupOption struct { + Group string +} + +func (o *groupOption) applyToClient(config *clientConfig) { + config.Group = o.Group +} + +func (o *groupOption) applyToHandler(config *handlerConfig) { + config.Group = o.Group +} + +type versionOption struct { + Version string +} + +func (o *versionOption) applyToClient(config *clientConfig) { + config.Version = o.Version +} + +func (o *versionOption) applyToHandler(config *handlerConfig) { + config.Version = o.Version +} + type idempotencyOption struct { idempotencyLevel IdempotencyLevel } diff --git a/protocol/triple/triple_protocol/server.go b/protocol/triple/triple_protocol/server.go new file mode 100644 index 0000000000..0c09c7bd9d --- /dev/null +++ b/protocol/triple/triple_protocol/server.go @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package triple_protocol + +import ( + "context" + "net/http" + "sync" +) + +import ( + "github.com/dubbogo/grpc-go" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +type Server struct { + mu sync.Mutex + handlers map[string]*Handler + httpSrv *http.Server +} + +func (s *Server) RegisterUnaryHandler( + procedure string, + reqInitFunc func() interface{}, + unary func(context.Context, *Request) (*Response, error), + options ...HandlerOption, +) error { + hdl, ok := s.handlers[procedure] + if !ok { + hdl = NewUnaryHandler(procedure, reqInitFunc, unary, options...) + s.handlers[procedure] = hdl + } else { + config := newHandlerConfig(procedure, options) + implementation := generateUnaryHandlerFunc(procedure, reqInitFunc, unary, config.Interceptor) + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + } + + return nil +} + +func (s *Server) RegisterClientStreamHandler( + procedure string, + stream func(context.Context, *ClientStream) (*Response, error), + options ...HandlerOption, +) error { + hdl, ok := s.handlers[procedure] + if !ok { + hdl = NewClientStreamHandler(procedure, stream, options...) + s.handlers[procedure] = hdl + } else { + config := newHandlerConfig(procedure, options) + implementation := generateClientStreamHandlerFunc(procedure, stream, config.Interceptor) + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + } + + return nil +} + +func (s *Server) RegisterServerStreamHandler( + procedure string, + reqInitFunc func() interface{}, + stream func(context.Context, *Request, *ServerStream) error, + options ...HandlerOption, +) error { + hdl, ok := s.handlers[procedure] + if !ok { + hdl = NewServerStreamHandler(procedure, reqInitFunc, stream, options...) + s.handlers[procedure] = hdl + } else { + config := newHandlerConfig(procedure, options) + implementation := generateServerStreamHandlerFunc(procedure, reqInitFunc, stream, config.Interceptor) + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + } + + return nil +} + +func (s *Server) RegisterBidiStreamHandler( + procedure string, + stream func(context.Context, *BidiStream) error, + options ...HandlerOption, +) error { + hdl, ok := s.handlers[procedure] + if !ok { + hdl = NewBidiStreamHandler(procedure, stream, options...) + s.handlers[procedure] = hdl + } else { + config := newHandlerConfig(procedure, options) + implementation := generateBidiStreamHandlerFunc(procedure, stream, config.Interceptor) + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + } + + return nil +} + +func (s *Server) RegisterCompatUnaryHandler( + procedure string, + srv interface{}, + unary MethodHandler, + options ...HandlerOption, +) error { + hdl, ok := s.handlers[procedure] + if !ok { + hdl = NewCompatUnaryHandler(procedure, srv, unary, options...) + s.handlers[procedure] = hdl + } else { + config := newHandlerConfig(procedure, options) + implementation := generateCompatUnaryHandlerFunc(procedure, srv, unary, config.Interceptor) + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + } + + return nil +} + +func (s *Server) RegisterCompatStreamHandler( + procedure string, + srv interface{}, + typ StreamType, + streamFunc func(srv interface{}, stream grpc.ServerStream) error, + options ...HandlerOption, +) error { + hdl, ok := s.handlers[procedure] + if !ok { + hdl = NewCompatStreamHandler(procedure, srv, typ, streamFunc, options...) + s.handlers[procedure] = hdl + } else { + config := newHandlerConfig(procedure, options) + implementation := generateCompatStreamHandlerFunc(procedure, srv, streamFunc, config.Interceptor) + hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation) + } + + return nil +} + +func (s *Server) Run() error { + mux := http.NewServeMux() + for procedure, hdl := range s.handlers { + mux.Handle(procedure, hdl) + } + // todo(DMwangnima): deal with TLS + s.httpSrv.Handler = h2c.NewHandler(mux, &http2.Server{}) + + if err := s.httpSrv.ListenAndServe(); err != nil { + return err + } + return nil +} + +func (s *Server) Stop() error { + return s.httpSrv.Close() +} + +func (s *Server) GracefulStop(ctx context.Context) error { + return s.httpSrv.Shutdown(ctx) +} + +func NewServer(addr string) *Server { + return &Server{ + handlers: make(map[string]*Handler), + httpSrv: &http.Server{Addr: addr}, + } +} diff --git a/protocol/triple/triple_protocol/triple.go b/protocol/triple/triple_protocol/triple.go index 251ad8fa89..e56dd4b751 100644 --- a/protocol/triple/triple_protocol/triple.go +++ b/protocol/triple/triple_protocol/triple.go @@ -54,6 +54,11 @@ const ( StreamTypeBidi = StreamTypeClient | StreamTypeServer ) +const ( + tripleServiceGroup = "tri-service-group" + tripleServiceVersion = "tri-service-version" +) + // StreamingHandlerConn is the server's view of a bidirectional message // exchange. Interceptors for streaming RPCs may wrap StreamingHandlerConns. // diff --git a/protocol/triple/triple_test.go b/protocol/triple/triple_test.go index b8b7c72fd8..acaf10c6f8 100644 --- a/protocol/triple/triple_test.go +++ b/protocol/triple/triple_test.go @@ -59,6 +59,8 @@ const ( listenAddr = "0.0.0.0" localAddr = "127.0.0.1" name = "triple" + group = "g1" + version = "v1" ) type tripleInvoker struct { @@ -96,12 +98,14 @@ func (t *tripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocati panic(fmt.Sprintf("no match method for %s", name)) } -func runTripleServer(interfaceName string, addr string, info *server.ServiceInfo, handler interface{}) { +func runTripleServer(interfaceName string, group string, version string, addr string, info *server.ServiceInfo, handler interface{}) { url := common.NewURLWithOptions( common.WithPath(interfaceName), common.WithLocation(addr), common.WithPort(triplePort), ) + url.SetParam(constant.GroupKey, group) + url.SetParam(constant.VersionKey, version) var invoker protocol.Invoker if info != nil { invoker = &tripleInvoker{ @@ -114,28 +118,32 @@ func runTripleServer(interfaceName string, addr string, info *server.ServiceInfo GetProtocol().(*TripleProtocol).exportForTest(invoker, info) } -func runOldTripleServer(addr string, desc *grpc_go.ServiceDesc) { +func runOldTripleServer(interfaceName string, group string, version string, addr string, desc *grpc_go.ServiceDesc, svc common.RPCService) { url := common.NewURLWithOptions( + // todo(DMwangnima): figure this out common.WithPath(desc.ServiceName), common.WithLocation(addr), common.WithPort(dubbo3Port), common.WithProtocol(TRIPLE), - common.WithInterface(desc.ServiceName), + common.WithInterface(interfaceName), ) - srv := new(dubbo3_api.GreetDubbo3Server) + url.SetParam(constant.GroupKey, group) + url.SetParam(constant.VersionKey, version) // todo(DMwangnima): add protocol config config.SetRootConfig( *config.NewRootConfigBuilder(). SetProvider( config.NewProviderConfigBuilder(). - AddService(common.GetReference(srv), config.NewServiceConfigBuilder(). - SetInterface(desc.ServiceName). + AddService(common.GetReference(svc), config.NewServiceConfigBuilder(). + SetInterface(interfaceName). + SetGroup(group). + SetVersion(version). Build()). SetProxyFactory("default"). Build()). Build()) - config.SetProviderService(srv) - common.ServiceMap.Register(desc.ServiceName, TRIPLE, "", "", srv) + config.SetProviderService(svc) + common.ServiceMap.Register(desc.ServiceName, TRIPLE, group, version, svc) invoker := extension.GetProxyFactory("default").GetInvoker(url) GetProtocol().(*TripleProtocol).exportForTest(invoker, nil) } @@ -143,13 +151,35 @@ func runOldTripleServer(addr string, desc *grpc_go.ServiceDesc) { func TestMain(m *testing.M) { runTripleServer( greettriple.GreetServiceName, + "", + "", listenAddr, &greettriple.GreetService_ServiceInfo, new(api.GreetTripleServer), ) + runTripleServer( + greettriple.GreetServiceName, + group, + version, + listenAddr, + &greettriple.GreetService_ServiceInfo, + new(api.GreetTripleServerGroup1Version1), + ) runOldTripleServer( + dubbo3_greet.GreetService_ServiceDesc.ServiceName, + "", + "", listenAddr, &dubbo3_greet.GreetService_ServiceDesc, + new(dubbo3_api.GreetDubbo3Server), + ) + runOldTripleServer( + dubbo3_greet.GreetService_ServiceDesc.ServiceName, + group, + version, + listenAddr, + &dubbo3_greet.GreetService_ServiceDesc, + new(dubbo3_api.GreetDubbo3ServerGroup1Version1), ) time.Sleep(3 * time.Second) m.Run() @@ -157,7 +187,7 @@ func TestMain(m *testing.M) { } func TestInvoke(t *testing.T) { - tripleInvokerInit := func(location string, port string, interfaceName string, methods []string, info *client.ClientInfo) (protocol.Invoker, error) { + tripleInvokerInit := func(location string, port string, interfaceName string, group string, version string, methods []string, info *client.ClientInfo) (protocol.Invoker, error) { newURL := common.NewURLWithOptions( common.WithInterface(interfaceName), common.WithLocation(location), @@ -165,14 +195,18 @@ func TestInvoke(t *testing.T) { common.WithMethods(methods), common.WithAttribute(constant.ClientInfoKey, info), ) + newURL.SetParam(constant.GroupKey, group) + newURL.SetParam(constant.VersionKey, version) return NewTripleInvoker(newURL) } - dubbo3InvokerInit := func(location string, port string, interfaceName string, svc common.RPCService) (protocol.Invoker, error) { + dubbo3InvokerInit := func(location string, port string, interfaceName string, group string, version string, svc common.RPCService) (protocol.Invoker, error) { newURL := common.NewURLWithOptions( common.WithInterface(interfaceName), common.WithLocation(location), common.WithPort(port), ) + newURL.SetParam(constant.GroupKey, group) + newURL.SetParam(constant.VersionKey, version) // dubbo3 needs to retrieve ConsumerService directly config.SetConsumerServiceByInterfaceName(interfaceName, svc) return NewDubbo3Invoker(newURL) @@ -204,7 +238,7 @@ func TestInvoke(t *testing.T) { return reply.Interface() } - invokeTripleCodeFunc := func(t *testing.T, invoker protocol.Invoker) { + invokeTripleCodeFunc := func(t *testing.T, invoker protocol.Invoker, identifier string) { tests := []struct { methodName string callType string @@ -225,7 +259,7 @@ func TestInvoke(t *testing.T) { assert.Nil(t, res.Error()) req := params[0].(*greet.GreetRequest) resp := params[1].(*greet.GreetResponse) - assert.Equal(t, req.Name, resp.Greeting) + assert.Equal(t, identifier+req.Name, resp.Greeting) }, }, { @@ -240,7 +274,7 @@ func TestInvoke(t *testing.T) { var expectRes []string times := 5 for i := 1; i <= times; i++ { - expectRes = append(expectRes, name) + expectRes = append(expectRes, identifier+name) err := stream.Send(&greet.GreetClientStreamRequest{Name: name}) assert.Nil(t, err) } @@ -268,7 +302,7 @@ func TestInvoke(t *testing.T) { for i := 1; i <= times; i++ { for stream.Recv() { assert.Nil(t, stream.Err()) - assert.Equal(t, req.Name, stream.Msg().Greeting) + assert.Equal(t, identifier+req.Name, stream.Msg().Greeting) } assert.True(t, true, errors.Is(stream.Err(), io.EOF)) } @@ -287,7 +321,7 @@ func TestInvoke(t *testing.T) { assert.Nil(t, err) resp, err := stream.Recv() assert.Nil(t, err) - assert.Equal(t, name, resp.Greeting) + assert.Equal(t, identifier+name, resp.Greeting) } assert.Nil(t, stream.CloseRequest()) assert.Nil(t, stream.CloseResponse()) @@ -303,7 +337,7 @@ func TestInvoke(t *testing.T) { }) } } - invokeDubbo3CodeFunc := func(t *testing.T, invoker protocol.Invoker, svc common.RPCService) { + invokeDubbo3CodeFunc := func(t *testing.T, invoker protocol.Invoker, svc common.RPCService, identifier string) { tests := []struct { methodName string params []reflect.Value @@ -320,7 +354,7 @@ func TestInvoke(t *testing.T) { assert.Nil(t, res.Error()) req := Params[0].Interface().(*greet.GreetRequest) resp := res.Result().(*greet.GreetResponse) - assert.Equal(t, req.Name, resp.Greeting) + assert.Equal(t, identifier+req.Name, resp.Greeting) }, }, { @@ -333,7 +367,7 @@ func TestInvoke(t *testing.T) { var expectRes []string times := 5 for i := 1; i <= times; i++ { - expectRes = append(expectRes, name) + expectRes = append(expectRes, identifier+name) err := (*stream).Send(&greet.GreetClientStreamRequest{Name: name}) assert.Nil(t, err) } @@ -359,7 +393,7 @@ func TestInvoke(t *testing.T) { for i := 1; i <= times; i++ { msg, err := (*stream).Recv() assert.Nil(t, err) - assert.Equal(t, req.Name, msg.Greeting) + assert.Equal(t, identifier+req.Name, msg.Greeting) } }, }, @@ -374,7 +408,7 @@ func TestInvoke(t *testing.T) { assert.Nil(t, err) resp, err := (*stream).Recv() assert.Nil(t, err) - assert.Equal(t, name, resp.Greeting) + assert.Equal(t, identifier+name, resp.Greeting) } assert.Nil(t, (*stream).CloseSend()) }, @@ -397,25 +431,47 @@ func TestInvoke(t *testing.T) { } t.Run("triple2triple", func(t *testing.T) { - invoker, err := tripleInvokerInit(localAddr, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo) + invoker, err := tripleInvokerInit(localAddr, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo) assert.Nil(t, err) - invokeTripleCodeFunc(t, invoker) + invokeTripleCodeFunc(t, invoker, "") + }) + t.Run("triple2triple_Group1Version1", func(t *testing.T) { + invoker, err := tripleInvokerInit(localAddr, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo) + assert.Nil(t, err) + invokeTripleCodeFunc(t, invoker, api.GroupVersionIdentifier) }) t.Run("triple2dubbo3", func(t *testing.T) { - invoker, err := tripleInvokerInit(localAddr, dubbo3Port, greettriple.GreetService_ClientInfo.InterfaceName, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo) + invoker, err := tripleInvokerInit(localAddr, dubbo3Port, greettriple.GreetService_ClientInfo.InterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo) + assert.Nil(t, err) + invokeTripleCodeFunc(t, invoker, "") + }) + t.Run("triple2dubbo3_Group1Version1", func(t *testing.T) { + invoker, err := tripleInvokerInit(localAddr, dubbo3Port, greettriple.GreetService_ClientInfo.InterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo) assert.Nil(t, err) - invokeTripleCodeFunc(t, invoker) + invokeTripleCodeFunc(t, invoker, dubbo3_api.GroupVersionIdentifier) }) t.Run("dubbo32triple", func(t *testing.T) { svc := new(dubbo3_greet.GreetServiceClientImpl) - invoker, err := dubbo3InvokerInit(localAddr, triplePort, dubbo3_greet.GreetService_ServiceDesc.ServiceName, svc) + invoker, err := dubbo3InvokerInit(localAddr, triplePort, dubbo3_greet.GreetService_ServiceDesc.ServiceName, "", "", svc) assert.Nil(t, err) - invokeDubbo3CodeFunc(t, invoker, svc) + invokeDubbo3CodeFunc(t, invoker, svc, "") + }) + t.Run("dubbo32triple_Group1Version1", func(t *testing.T) { + svc := new(dubbo3_greet.GreetServiceClientImpl) + invoker, err := dubbo3InvokerInit(localAddr, triplePort, dubbo3_greet.GreetService_ServiceDesc.ServiceName, group, version, svc) + assert.Nil(t, err) + invokeDubbo3CodeFunc(t, invoker, svc, api.GroupVersionIdentifier) }) t.Run("dubbo32dubbo3", func(t *testing.T) { svc := new(dubbo3_greet.GreetServiceClientImpl) - invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, dubbo3_greet.GreetService_ServiceDesc.ServiceName, svc) + invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, dubbo3_greet.GreetService_ServiceDesc.ServiceName, "", "", svc) + assert.Nil(t, err) + invokeDubbo3CodeFunc(t, invoker, svc, "") + }) + t.Run("dubbo32dubbo3_Group1Version1", func(t *testing.T) { + svc := new(dubbo3_greet.GreetServiceClientImpl) + invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, dubbo3_greet.GreetService_ServiceDesc.ServiceName, group, version, svc) assert.Nil(t, err) - invokeDubbo3CodeFunc(t, invoker, svc) + invokeDubbo3CodeFunc(t, invoker, svc, dubbo3_api.GroupVersionIdentifier) }) }