Skip to content

Commit

Permalink
Added cross DC outbound builder (cadence-workflow#4552)
Browse files Browse the repository at this point in the history
  • Loading branch information
vytautas-karpavicius authored Oct 11, 2021
1 parent 7a6b851 commit 9ff3eb3
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 38 deletions.
32 changes: 2 additions & 30 deletions client/clientBean.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,12 @@ import (
"sync/atomic"

"go.uber.org/yarpc"
"go.uber.org/yarpc/transport/grpc"
"go.uber.org/yarpc/transport/tchannel"

"github.com/uber/cadence/client/admin"
"github.com/uber/cadence/client/frontend"
"github.com/uber/cadence/client/history"
"github.com/uber/cadence/client/matching"
"github.com/uber/cadence/common/authorization"
"github.com/uber/cadence/common/cluster"
"github.com/uber/cadence/common/rpc"
)

type (
Expand Down Expand Up @@ -67,7 +63,7 @@ type (
)

// NewClientBean provides a collection of clients
func NewClientBean(factory Factory, dispatcherProvider rpc.DispatcherProvider, clusterMetadata cluster.Metadata) (Bean, error) {
func NewClientBean(factory Factory, dispatcher *yarpc.Dispatcher, clusterMetadata cluster.Metadata) (Bean, error) {

historyClient, err := factory.NewHistoryClient()
if err != nil {
Expand All @@ -81,31 +77,7 @@ func NewClientBean(factory Factory, dispatcherProvider rpc.DispatcherProvider, c
continue
}

var dispatcherOptions *rpc.DispatcherOptions
if info.AuthorizationProvider.Enable {
authProvider, err := authorization.GetAuthProviderClient(info.AuthorizationProvider.PrivateKey)
if err != nil {
return nil, err
}
dispatcherOptions = &rpc.DispatcherOptions{
AuthProvider: authProvider,
}
}

var dispatcher *yarpc.Dispatcher
var err error
switch info.RPCTransport {
case tchannel.TransportName:
dispatcher, err = dispatcherProvider.GetTChannel(info.RPCName, info.RPCAddress, dispatcherOptions)
case grpc.TransportName:
dispatcher, err = dispatcherProvider.GetGRPC(info.RPCName, info.RPCAddress, dispatcherOptions)
}

if err != nil {
return nil, err
}

clientConfig := dispatcher.ClientConfig(info.RPCName)
clientConfig := dispatcher.ClientConfig(clusterName)

adminClient, err := factory.NewAdminClientWithTimeoutAndConfig(
clientConfig,
Expand Down
4 changes: 4 additions & 0 deletions cmd/server/cadence/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ func (s *server) startService() common.Daemon {
if err != nil {
log.Fatalf("error creating rpc factory params: %v", err)
}
rpcParams.OutboundsBuilder = rpc.CombineOutbounds(
rpcParams.OutboundsBuilder,
rpc.NewCrossDCOutbounds(clusterGroupMetadata.ClusterGroup, rpc.NewDNSPeerChooserFactory(s.cfg.PublicClient.RefreshInterval, params.Logger)),
)
params.RPCFactory = rpc.NewFactory(params.Logger, rpcParams)
dispatcher := params.RPCFactory.GetDispatcher()

Expand Down
2 changes: 1 addition & 1 deletion common/resource/resourceImpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func New(
numShards,
logger,
),
params.DispatcherProvider,
params.RPCFactory.GetDispatcher(),
params.ClusterMetadata,
)
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions common/rpc/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,12 @@ func (m *inboundMetricsMiddleware) Handle(ctx context.Context, req *transport.Re
)
return h.Handle(ctx, req, resw)
}

type overrideCallerMiddleware struct {
caller string
}

func (m *overrideCallerMiddleware) Call(ctx context.Context, request *transport.Request, out transport.UnaryOutbound) (*transport.Response, error) {
request.Caller = m.caller
return out.Call(ctx, request)
}
8 changes: 8 additions & 0 deletions common/rpc/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ func TestInboundMetricsMiddleware(t *testing.T) {
})
}

func TestOverrideCallerMiddleware(t *testing.T) {
m := overrideCallerMiddleware{"x-caller"}
_, err := m.Call(context.Background(), &transport.Request{Caller: "service"}, &fakeOutbound{verify: func(r *transport.Request) {
assert.Equal(t, "x-caller", r.Caller)
}})
assert.NoError(t, err)
}

type fakeHandler struct {
ctx context.Context
}
Expand Down
55 changes: 55 additions & 0 deletions common/rpc/outbounds.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"go.uber.org/multierr"
"go.uber.org/yarpc"
"go.uber.org/yarpc/api/middleware"
"go.uber.org/yarpc/api/transport"
"go.uber.org/yarpc/transport/grpc"
"go.uber.org/yarpc/transport/tchannel"
)
Expand Down Expand Up @@ -106,3 +107,57 @@ func (b publicClientOutbound) Build(_ *grpc.Transport, tchannel *tchannel.Transp
},
}, nil
}

type crossDCOutbounds struct {
clusterGroup map[string]config.ClusterInformation
pcf PeerChooserFactory
}

func NewCrossDCOutbounds(clusterGroup map[string]config.ClusterInformation, pcf PeerChooserFactory) OutboundsBuilder {
return crossDCOutbounds{clusterGroup, pcf}
}

func (b crossDCOutbounds) Build(grpcTransport *grpc.Transport, tchannelTransport *tchannel.Transport) (yarpc.Outbounds, error) {
outbounds := yarpc.Outbounds{}
for clusterName, clusterInfo := range b.clusterGroup {
if !clusterInfo.Enabled {
continue
}

var outbound transport.UnaryOutbound
switch clusterInfo.RPCTransport {
case tchannel.TransportName:
peerChooser, err := b.pcf.CreatePeerChooser(tchannelTransport, clusterInfo.RPCAddress)
if err != nil {
return nil, err
}
outbound = tchannelTransport.NewOutbound(peerChooser)
case grpc.TransportName:
peerChooser, err := b.pcf.CreatePeerChooser(grpcTransport, clusterInfo.RPCAddress)
if err != nil {
return nil, err
}
outbound = grpcTransport.NewOutbound(peerChooser)
default:
return nil, fmt.Errorf("unknown cross DC transport type: %s", clusterInfo.RPCTransport)
}

var authMiddleware middleware.UnaryOutbound
if clusterInfo.AuthorizationProvider.Enable {
authProvider, err := authorization.GetAuthProviderClient(clusterInfo.AuthorizationProvider.PrivateKey)
if err != nil {
return nil, fmt.Errorf("create AuthProvider: %v", err)
}
authMiddleware = &authOutboundMiddleware{authProvider}
}

outbounds[clusterName] = transport.Outbounds{
ServiceName: clusterInfo.RPCName,
Unary: middleware.ApplyUnaryOutbound(outbound, yarpc.UnaryOutboundMiddleware(
authMiddleware,
&overrideCallerMiddleware{crossDCCaller},
)),
}
}
return outbounds, nil
}
38 changes: 38 additions & 0 deletions common/rpc/outbounds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/yarpc"
"go.uber.org/yarpc/api/peer"
"go.uber.org/yarpc/peer/direct"
"go.uber.org/yarpc/transport/grpc"
"go.uber.org/yarpc/transport/tchannel"
)
Expand Down Expand Up @@ -119,6 +121,36 @@ func TestPublicClientOutbound(t *testing.T) {
assert.NotNil(t, outbounds[OutboundPublicClient].Unary)
}

func TestCrossDCOutbounds(t *testing.T) {
grpc := &grpc.Transport{}
tchannel := &tchannel.Transport{}

clusterGroup := map[string]config.ClusterInformation{
"cluster-A": {Enabled: true, RPCName: "cadence-frontend", RPCTransport: "invalid"},
}
_, err := NewCrossDCOutbounds(clusterGroup, &fakePeerChooserFactory{}).Build(grpc, tchannel)
assert.EqualError(t, err, "unknown cross DC transport type: invalid")

clusterGroup = map[string]config.ClusterInformation{
"cluster-A": {Enabled: true, RPCName: "cadence-frontend", RPCTransport: "grpc", AuthorizationProvider: config.AuthorizationProvider{Enable: true, PrivateKey: "invalid path"}},
}
_, err = NewCrossDCOutbounds(clusterGroup, &fakePeerChooserFactory{}).Build(grpc, tchannel)
assert.EqualError(t, err, "create AuthProvider: invalid private key path invalid path")

clusterGroup = map[string]config.ClusterInformation{
"cluster-A": {Enabled: true, RPCName: "cadence-frontend", RPCAddress: "address-A", RPCTransport: "grpc", AuthorizationProvider: config.AuthorizationProvider{Enable: true, PrivateKey: tempFile(t, "key")}},
"cluster-B": {Enabled: true, RPCName: "cadence-frontend", RPCAddress: "address-B", RPCTransport: "tchannel"},
"cluster-C": {Enabled: false},
}
outbounds, err := NewCrossDCOutbounds(clusterGroup, &fakePeerChooserFactory{}).Build(grpc, tchannel)
assert.NoError(t, err)
assert.Equal(t, 2, len(outbounds))
assert.Equal(t, "cadence-frontend", outbounds["cluster-A"].ServiceName)
assert.Equal(t, "cadence-frontend", outbounds["cluster-B"].ServiceName)
assert.NotNil(t, outbounds["cluster-A"].Unary)
assert.NotNil(t, outbounds["cluster-B"].Unary)
}

func tempFile(t *testing.T, content string) string {
f, err := ioutil.TempFile("", "")
require.NoError(t, err)
Expand All @@ -140,3 +172,9 @@ type fakeOutboundBuilder struct {
func (b fakeOutboundBuilder) Build(*grpc.Transport, *tchannel.Transport) (yarpc.Outbounds, error) {
return b.outbounds, b.err
}

type fakePeerChooserFactory struct{}

func (f fakePeerChooserFactory) CreatePeerChooser(transport peer.Transport, address string) (peer.Chooser, error) {
return direct.New(direct.Configuration{}, transport)
}
14 changes: 8 additions & 6 deletions host/onebox.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ func (c *cadenceImpl) startFrontend(hosts map[string][]string, startWG *sync.Wai
params.ThrottledLogger = c.logger
params.UpdateLoggerWithServiceName(service.Frontend)
params.PProfInitializer = newPProfInitializerImpl(c.logger, c.FrontendPProfPort())
params.RPCFactory = newRPCFactory(service.Frontend, c.FrontendAddress(), c.logger)
params.RPCFactory = newRPCFactory(service.Frontend, c.FrontendAddress(), c.logger, c.clusterMetadata)
params.MetricScope = tally.NewTestScope(service.Frontend, make(map[string]string))
params.MembershipFactory = newMembershipFactory(params.Name, hosts)
params.ClusterMetadata = c.clusterMetadata
Expand Down Expand Up @@ -465,7 +465,7 @@ func (c *cadenceImpl) startHistory(
params.ThrottledLogger = c.logger
params.UpdateLoggerWithServiceName(service.History)
params.PProfInitializer = newPProfInitializerImpl(c.logger, pprofPorts[i])
params.RPCFactory = newRPCFactory(service.History, hostport, c.logger)
params.RPCFactory = newRPCFactory(service.History, hostport, c.logger, c.clusterMetadata)
params.MetricScope = tally.NewTestScope(service.History, make(map[string]string))
params.MembershipFactory = newMembershipFactory(params.Name, hosts)
params.ClusterMetadata = c.clusterMetadata
Expand Down Expand Up @@ -535,7 +535,7 @@ func (c *cadenceImpl) startMatching(hosts map[string][]string, startWG *sync.Wai
params.ThrottledLogger = c.logger
params.UpdateLoggerWithServiceName(service.Matching)
params.PProfInitializer = newPProfInitializerImpl(c.logger, c.MatchingPProfPort())
params.RPCFactory = newRPCFactory(service.Matching, c.MatchingServiceAddress(), c.logger)
params.RPCFactory = newRPCFactory(service.Matching, c.MatchingServiceAddress(), c.logger, c.clusterMetadata)
params.MetricScope = tally.NewTestScope(service.Matching, make(map[string]string))
params.MembershipFactory = newMembershipFactory(params.Name, hosts)
params.ClusterMetadata = c.clusterMetadata
Expand Down Expand Up @@ -578,7 +578,7 @@ func (c *cadenceImpl) startWorker(hosts map[string][]string, startWG *sync.WaitG
params.ThrottledLogger = c.logger
params.UpdateLoggerWithServiceName(service.Worker)
params.PProfInitializer = newPProfInitializerImpl(c.logger, c.WorkerPProfPort())
params.RPCFactory = newRPCFactory(service.Worker, c.WorkerServiceAddress(), c.logger)
params.RPCFactory = newRPCFactory(service.Worker, c.WorkerServiceAddress(), c.logger, c.clusterMetadata)
params.MetricScope = tally.NewTestScope(service.Worker, make(map[string]string))
params.MembershipFactory = newMembershipFactory(params.Name, hosts)
params.ClusterMetadata = c.clusterMetadata
Expand Down Expand Up @@ -791,7 +791,7 @@ func newPProfInitializerImpl(logger log.Logger, port int) common.PProfInitialize
}
}

func newRPCFactory(serviceName string, tchannelHostPort string, logger log.Logger) common.RPCFactory {
func newRPCFactory(serviceName string, tchannelHostPort string, logger log.Logger, cluster cluster.Metadata) common.RPCFactory {
grpcPortResolver := grpcPortResolver{}
grpcHostPort, err := grpcPortResolver.GetGRPCAddress("", tchannelHostPort)
if err != nil {
Expand All @@ -807,7 +807,9 @@ func newRPCFactory(serviceName string, tchannelHostPort string, logger log.Logge
Unary: &versionMiddleware{},
},
// For integration tests to generate client out of the same outbound.
OutboundsBuilder: &singleTChannelOutbound{serviceName, tchannelHostPort},
OutboundsBuilder: rpc.CombineOutbounds(
&singleTChannelOutbound{serviceName, tchannelHostPort},
rpc.NewCrossDCOutbounds(cluster.GetAllClusterInfo(), rpc.NewDNSPeerChooserFactory(0, logger))),
})
}

Expand Down
2 changes: 1 addition & 1 deletion host/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (h *serviceImpl) Start() {

h.clientBean, err = client.NewClientBean(
client.NewRPCClientFactory(h.rpcFactory, h.membershipMonitor, h.metricsClient, h.dynamicCollection, h.numberOfHistoryShards, h.logger),
h.dispatcherProvider,
h.rpcFactory.GetDispatcher(),
h.clusterMetadata,
)
if err != nil {
Expand Down

0 comments on commit 9ff3eb3

Please sign in to comment.