diff --git a/client/clientBean.go b/client/clientBean.go index 27a168054c0..61a723411ba 100644 --- a/client/clientBean.go +++ b/client/clientBean.go @@ -28,16 +28,12 @@ import ( "sync/atomic" "go.uber.org/yarpc" - "go.uber.org/yarpc/transport/grpc" - "go.uber.org/yarpc/transport/tchannel" "github.com/uber/cadence/client/admin" "github.com/uber/cadence/client/frontend" "github.com/uber/cadence/client/history" "github.com/uber/cadence/client/matching" - "github.com/uber/cadence/common/authorization" "github.com/uber/cadence/common/cluster" - "github.com/uber/cadence/common/rpc" ) type ( @@ -67,7 +63,7 @@ type ( ) // NewClientBean provides a collection of clients -func NewClientBean(factory Factory, dispatcherProvider rpc.DispatcherProvider, clusterMetadata cluster.Metadata) (Bean, error) { +func NewClientBean(factory Factory, dispatcher *yarpc.Dispatcher, clusterMetadata cluster.Metadata) (Bean, error) { historyClient, err := factory.NewHistoryClient() if err != nil { @@ -81,31 +77,7 @@ func NewClientBean(factory Factory, dispatcherProvider rpc.DispatcherProvider, c continue } - var dispatcherOptions *rpc.DispatcherOptions - if info.AuthorizationProvider.Enable { - authProvider, err := authorization.GetAuthProviderClient(info.AuthorizationProvider.PrivateKey) - if err != nil { - return nil, err - } - dispatcherOptions = &rpc.DispatcherOptions{ - AuthProvider: authProvider, - } - } - - var dispatcher *yarpc.Dispatcher - var err error - switch info.RPCTransport { - case tchannel.TransportName: - dispatcher, err = dispatcherProvider.GetTChannel(info.RPCName, info.RPCAddress, dispatcherOptions) - case grpc.TransportName: - dispatcher, err = dispatcherProvider.GetGRPC(info.RPCName, info.RPCAddress, dispatcherOptions) - } - - if err != nil { - return nil, err - } - - clientConfig := dispatcher.ClientConfig(info.RPCName) + clientConfig := dispatcher.ClientConfig(clusterName) adminClient, err := factory.NewAdminClientWithTimeoutAndConfig( clientConfig, diff --git a/cmd/server/cadence/server.go b/cmd/server/cadence/server.go index 3b318687f29..1ad2a7b3255 100644 --- a/cmd/server/cadence/server.go +++ b/cmd/server/cadence/server.go @@ -149,6 +149,10 @@ func (s *server) startService() common.Daemon { if err != nil { log.Fatalf("error creating rpc factory params: %v", err) } + rpcParams.OutboundsBuilder = rpc.CombineOutbounds( + rpcParams.OutboundsBuilder, + rpc.NewCrossDCOutbounds(clusterGroupMetadata.ClusterGroup, rpc.NewDNSPeerChooserFactory(s.cfg.PublicClient.RefreshInterval, params.Logger)), + ) params.RPCFactory = rpc.NewFactory(params.Logger, rpcParams) dispatcher := params.RPCFactory.GetDispatcher() diff --git a/common/resource/resourceImpl.go b/common/resource/resourceImpl.go index 9948545c70a..fa13e16754f 100644 --- a/common/resource/resourceImpl.go +++ b/common/resource/resourceImpl.go @@ -167,7 +167,7 @@ func New( numShards, logger, ), - params.DispatcherProvider, + params.RPCFactory.GetDispatcher(), params.ClusterMetadata, ) if err != nil { diff --git a/common/rpc/middleware.go b/common/rpc/middleware.go index 232b240075e..79ecc836a3a 100644 --- a/common/rpc/middleware.go +++ b/common/rpc/middleware.go @@ -87,3 +87,12 @@ func (m *inboundMetricsMiddleware) Handle(ctx context.Context, req *transport.Re ) return h.Handle(ctx, req, resw) } + +type overrideCallerMiddleware struct { + caller string +} + +func (m *overrideCallerMiddleware) Call(ctx context.Context, request *transport.Request, out transport.UnaryOutbound) (*transport.Response, error) { + request.Caller = m.caller + return out.Call(ctx, request) +} diff --git a/common/rpc/middleware_test.go b/common/rpc/middleware_test.go index 7eec8c13335..7c81910f217 100644 --- a/common/rpc/middleware_test.go +++ b/common/rpc/middleware_test.go @@ -68,6 +68,14 @@ func TestInboundMetricsMiddleware(t *testing.T) { }) } +func TestOverrideCallerMiddleware(t *testing.T) { + m := overrideCallerMiddleware{"x-caller"} + _, err := m.Call(context.Background(), &transport.Request{Caller: "service"}, &fakeOutbound{verify: func(r *transport.Request) { + assert.Equal(t, "x-caller", r.Caller) + }}) + assert.NoError(t, err) +} + type fakeHandler struct { ctx context.Context } diff --git a/common/rpc/outbounds.go b/common/rpc/outbounds.go index 91883aa1ed5..737060a7fa9 100644 --- a/common/rpc/outbounds.go +++ b/common/rpc/outbounds.go @@ -30,6 +30,7 @@ import ( "go.uber.org/multierr" "go.uber.org/yarpc" "go.uber.org/yarpc/api/middleware" + "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/transport/grpc" "go.uber.org/yarpc/transport/tchannel" ) @@ -106,3 +107,57 @@ func (b publicClientOutbound) Build(_ *grpc.Transport, tchannel *tchannel.Transp }, }, nil } + +type crossDCOutbounds struct { + clusterGroup map[string]config.ClusterInformation + pcf PeerChooserFactory +} + +func NewCrossDCOutbounds(clusterGroup map[string]config.ClusterInformation, pcf PeerChooserFactory) OutboundsBuilder { + return crossDCOutbounds{clusterGroup, pcf} +} + +func (b crossDCOutbounds) Build(grpcTransport *grpc.Transport, tchannelTransport *tchannel.Transport) (yarpc.Outbounds, error) { + outbounds := yarpc.Outbounds{} + for clusterName, clusterInfo := range b.clusterGroup { + if !clusterInfo.Enabled { + continue + } + + var outbound transport.UnaryOutbound + switch clusterInfo.RPCTransport { + case tchannel.TransportName: + peerChooser, err := b.pcf.CreatePeerChooser(tchannelTransport, clusterInfo.RPCAddress) + if err != nil { + return nil, err + } + outbound = tchannelTransport.NewOutbound(peerChooser) + case grpc.TransportName: + peerChooser, err := b.pcf.CreatePeerChooser(grpcTransport, clusterInfo.RPCAddress) + if err != nil { + return nil, err + } + outbound = grpcTransport.NewOutbound(peerChooser) + default: + return nil, fmt.Errorf("unknown cross DC transport type: %s", clusterInfo.RPCTransport) + } + + var authMiddleware middleware.UnaryOutbound + if clusterInfo.AuthorizationProvider.Enable { + authProvider, err := authorization.GetAuthProviderClient(clusterInfo.AuthorizationProvider.PrivateKey) + if err != nil { + return nil, fmt.Errorf("create AuthProvider: %v", err) + } + authMiddleware = &authOutboundMiddleware{authProvider} + } + + outbounds[clusterName] = transport.Outbounds{ + ServiceName: clusterInfo.RPCName, + Unary: middleware.ApplyUnaryOutbound(outbound, yarpc.UnaryOutboundMiddleware( + authMiddleware, + &overrideCallerMiddleware{crossDCCaller}, + )), + } + } + return outbounds, nil +} diff --git a/common/rpc/outbounds_test.go b/common/rpc/outbounds_test.go index 4ae29e0fd5c..0fbd51eb097 100644 --- a/common/rpc/outbounds_test.go +++ b/common/rpc/outbounds_test.go @@ -31,6 +31,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/yarpc" + "go.uber.org/yarpc/api/peer" + "go.uber.org/yarpc/peer/direct" "go.uber.org/yarpc/transport/grpc" "go.uber.org/yarpc/transport/tchannel" ) @@ -119,6 +121,36 @@ func TestPublicClientOutbound(t *testing.T) { assert.NotNil(t, outbounds[OutboundPublicClient].Unary) } +func TestCrossDCOutbounds(t *testing.T) { + grpc := &grpc.Transport{} + tchannel := &tchannel.Transport{} + + clusterGroup := map[string]config.ClusterInformation{ + "cluster-A": {Enabled: true, RPCName: "cadence-frontend", RPCTransport: "invalid"}, + } + _, err := NewCrossDCOutbounds(clusterGroup, &fakePeerChooserFactory{}).Build(grpc, tchannel) + assert.EqualError(t, err, "unknown cross DC transport type: invalid") + + clusterGroup = map[string]config.ClusterInformation{ + "cluster-A": {Enabled: true, RPCName: "cadence-frontend", RPCTransport: "grpc", AuthorizationProvider: config.AuthorizationProvider{Enable: true, PrivateKey: "invalid path"}}, + } + _, err = NewCrossDCOutbounds(clusterGroup, &fakePeerChooserFactory{}).Build(grpc, tchannel) + assert.EqualError(t, err, "create AuthProvider: invalid private key path invalid path") + + clusterGroup = map[string]config.ClusterInformation{ + "cluster-A": {Enabled: true, RPCName: "cadence-frontend", RPCAddress: "address-A", RPCTransport: "grpc", AuthorizationProvider: config.AuthorizationProvider{Enable: true, PrivateKey: tempFile(t, "key")}}, + "cluster-B": {Enabled: true, RPCName: "cadence-frontend", RPCAddress: "address-B", RPCTransport: "tchannel"}, + "cluster-C": {Enabled: false}, + } + outbounds, err := NewCrossDCOutbounds(clusterGroup, &fakePeerChooserFactory{}).Build(grpc, tchannel) + assert.NoError(t, err) + assert.Equal(t, 2, len(outbounds)) + assert.Equal(t, "cadence-frontend", outbounds["cluster-A"].ServiceName) + assert.Equal(t, "cadence-frontend", outbounds["cluster-B"].ServiceName) + assert.NotNil(t, outbounds["cluster-A"].Unary) + assert.NotNil(t, outbounds["cluster-B"].Unary) +} + func tempFile(t *testing.T, content string) string { f, err := ioutil.TempFile("", "") require.NoError(t, err) @@ -140,3 +172,9 @@ type fakeOutboundBuilder struct { func (b fakeOutboundBuilder) Build(*grpc.Transport, *tchannel.Transport) (yarpc.Outbounds, error) { return b.outbounds, b.err } + +type fakePeerChooserFactory struct{} + +func (f fakePeerChooserFactory) CreatePeerChooser(transport peer.Transport, address string) (peer.Chooser, error) { + return direct.New(direct.Configuration{}, transport) +} diff --git a/host/onebox.go b/host/onebox.go index 42359d0c7d3..4ec9c352ff3 100644 --- a/host/onebox.go +++ b/host/onebox.go @@ -398,7 +398,7 @@ func (c *cadenceImpl) startFrontend(hosts map[string][]string, startWG *sync.Wai params.ThrottledLogger = c.logger params.UpdateLoggerWithServiceName(service.Frontend) params.PProfInitializer = newPProfInitializerImpl(c.logger, c.FrontendPProfPort()) - params.RPCFactory = newRPCFactory(service.Frontend, c.FrontendAddress(), c.logger) + params.RPCFactory = newRPCFactory(service.Frontend, c.FrontendAddress(), c.logger, c.clusterMetadata) params.MetricScope = tally.NewTestScope(service.Frontend, make(map[string]string)) params.MembershipFactory = newMembershipFactory(params.Name, hosts) params.ClusterMetadata = c.clusterMetadata @@ -465,7 +465,7 @@ func (c *cadenceImpl) startHistory( params.ThrottledLogger = c.logger params.UpdateLoggerWithServiceName(service.History) params.PProfInitializer = newPProfInitializerImpl(c.logger, pprofPorts[i]) - params.RPCFactory = newRPCFactory(service.History, hostport, c.logger) + params.RPCFactory = newRPCFactory(service.History, hostport, c.logger, c.clusterMetadata) params.MetricScope = tally.NewTestScope(service.History, make(map[string]string)) params.MembershipFactory = newMembershipFactory(params.Name, hosts) params.ClusterMetadata = c.clusterMetadata @@ -535,7 +535,7 @@ func (c *cadenceImpl) startMatching(hosts map[string][]string, startWG *sync.Wai params.ThrottledLogger = c.logger params.UpdateLoggerWithServiceName(service.Matching) params.PProfInitializer = newPProfInitializerImpl(c.logger, c.MatchingPProfPort()) - params.RPCFactory = newRPCFactory(service.Matching, c.MatchingServiceAddress(), c.logger) + params.RPCFactory = newRPCFactory(service.Matching, c.MatchingServiceAddress(), c.logger, c.clusterMetadata) params.MetricScope = tally.NewTestScope(service.Matching, make(map[string]string)) params.MembershipFactory = newMembershipFactory(params.Name, hosts) params.ClusterMetadata = c.clusterMetadata @@ -578,7 +578,7 @@ func (c *cadenceImpl) startWorker(hosts map[string][]string, startWG *sync.WaitG params.ThrottledLogger = c.logger params.UpdateLoggerWithServiceName(service.Worker) params.PProfInitializer = newPProfInitializerImpl(c.logger, c.WorkerPProfPort()) - params.RPCFactory = newRPCFactory(service.Worker, c.WorkerServiceAddress(), c.logger) + params.RPCFactory = newRPCFactory(service.Worker, c.WorkerServiceAddress(), c.logger, c.clusterMetadata) params.MetricScope = tally.NewTestScope(service.Worker, make(map[string]string)) params.MembershipFactory = newMembershipFactory(params.Name, hosts) params.ClusterMetadata = c.clusterMetadata @@ -791,7 +791,7 @@ func newPProfInitializerImpl(logger log.Logger, port int) common.PProfInitialize } } -func newRPCFactory(serviceName string, tchannelHostPort string, logger log.Logger) common.RPCFactory { +func newRPCFactory(serviceName string, tchannelHostPort string, logger log.Logger, cluster cluster.Metadata) common.RPCFactory { grpcPortResolver := grpcPortResolver{} grpcHostPort, err := grpcPortResolver.GetGRPCAddress("", tchannelHostPort) if err != nil { @@ -807,7 +807,9 @@ func newRPCFactory(serviceName string, tchannelHostPort string, logger log.Logge Unary: &versionMiddleware{}, }, // For integration tests to generate client out of the same outbound. - OutboundsBuilder: &singleTChannelOutbound{serviceName, tchannelHostPort}, + OutboundsBuilder: rpc.CombineOutbounds( + &singleTChannelOutbound{serviceName, tchannelHostPort}, + rpc.NewCrossDCOutbounds(cluster.GetAllClusterInfo(), rpc.NewDNSPeerChooserFactory(0, logger))), }) } diff --git a/host/service.go b/host/service.go index ef32c58c201..ab533a403ff 100644 --- a/host/service.go +++ b/host/service.go @@ -187,7 +187,7 @@ func (h *serviceImpl) Start() { h.clientBean, err = client.NewClientBean( client.NewRPCClientFactory(h.rpcFactory, h.membershipMonitor, h.metricsClient, h.dynamicCollection, h.numberOfHistoryShards, h.logger), - h.dispatcherProvider, + h.rpcFactory.GetDispatcher(), h.clusterMetadata, ) if err != nil {