diff --git a/client/history/peerResolver.go b/client/history/peerResolver.go index 9ccbc7af811..8a542e42cc7 100644 --- a/client/history/peerResolver.go +++ b/client/history/peerResolver.go @@ -67,9 +67,10 @@ func (pr PeerResolver) FromShardID(shardID int) (string, error) { shardIDString := string(rune(shardID)) host, err := pr.resolver.Lookup(service.History, shardIDString) if err != nil { - return "", err + return "", common.ToServiceTransientError(err) } - return host.GetNamedAddress(pr.namedPort) + peer, err := host.GetNamedAddress(pr.namedPort) + return peer, common.ToServiceTransientError(err) } // FromHostAddress resolves the final history peer responsible for the given host address. @@ -77,22 +78,23 @@ func (pr PeerResolver) FromShardID(shardID int) (string, error) { func (pr PeerResolver) FromHostAddress(hostAddress string) (string, error) { host, err := pr.resolver.LookupByAddress(service.History, hostAddress) if err != nil { - return "", err + return "", common.ToServiceTransientError(err) } - return host.GetNamedAddress(pr.namedPort) + peer, err := host.GetNamedAddress(pr.namedPort) + return peer, common.ToServiceTransientError(err) } // GetAllPeers returns all history service peers in the cluster ring. func (pr PeerResolver) GetAllPeers() ([]string, error) { hosts, err := pr.resolver.Members(service.History) if err != nil { - return nil, err + return nil, common.ToServiceTransientError(err) } peers := make([]string, 0, len(hosts)) for _, host := range hosts { peer, err := host.GetNamedAddress(pr.namedPort) if err != nil { - return nil, err + return nil, common.ToServiceTransientError(err) } peers = append(peers, peer) } diff --git a/client/matching/peerResolver.go b/client/matching/peerResolver.go index 97be8ab85d9..60251cf59d1 100644 --- a/client/matching/peerResolver.go +++ b/client/matching/peerResolver.go @@ -21,6 +21,7 @@ package matching import ( + "github.com/uber/cadence/common" "github.com/uber/cadence/common/membership" "github.com/uber/cadence/common/service" ) @@ -47,23 +48,24 @@ func NewPeerResolver(membership membership.Resolver, namedPort string) PeerResol func (pr PeerResolver) FromTaskList(taskListName string) (string, error) { host, err := pr.resolver.Lookup(service.Matching, taskListName) if err != nil { - return "", err + return "", common.ToServiceTransientError(err) } - return pr.FromHostAddress(host.GetAddress()) + peer, err := host.GetNamedAddress(pr.namedPort) + return peer, common.ToServiceTransientError(err) } // GetAllPeers returns all matching service peers in the cluster ring. func (pr PeerResolver) GetAllPeers() ([]string, error) { hosts, err := pr.resolver.Members(service.Matching) if err != nil { - return nil, err + return nil, common.ToServiceTransientError(err) } peers := make([]string, 0, len(hosts)) for _, host := range hosts { peer, err := pr.FromHostAddress(host.GetAddress()) if err != nil { - return nil, err + return nil, common.ToServiceTransientError(err) } peers = append(peers, peer) } @@ -76,9 +78,9 @@ func (pr PeerResolver) GetAllPeers() ([]string, error) { func (pr PeerResolver) FromHostAddress(hostAddress string) (string, error) { host, err := pr.resolver.LookupByAddress(service.Matching, hostAddress) if err != nil { - return "", err + return "", common.ToServiceTransientError(err) } - return host.GetNamedAddress(pr.namedPort) - + peer, err := host.GetNamedAddress(pr.namedPort) + return peer, common.ToServiceTransientError(err) } diff --git a/common/util.go b/common/util.go index 5bda8118cd2..b6fc7b7b473 100644 --- a/common/util.go +++ b/common/util.go @@ -263,6 +263,14 @@ func CheckDecisionResultLimit( return nil } +// ToServiceTransientError converts an error to ServiceTransientError +func ToServiceTransientError(err error) error { + if err == nil || IsServiceTransientError(err) { + return err + } + return yarpcerrors.Newf(yarpcerrors.CodeUnavailable, err.Error()) +} + // IsServiceTransientError checks if the error is a transient error. func IsServiceTransientError(err error) bool { switch err.(type) { diff --git a/common/util_test.go b/common/util_test.go index 21eb982b231..93dfb8d04c9 100644 --- a/common/util_test.go +++ b/common/util_test.go @@ -32,6 +32,7 @@ import ( "time" "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/yarpc/yarpcerrors" @@ -341,3 +342,20 @@ func TestConvertErrToGetTaskFailedCause(t *testing.T) { require.Equal(t, tc.expectedFailedCause, ConvertErrToGetTaskFailedCause(tc.err)) } } + +func TestToServiceTransientError(t *testing.T) { + t.Run("it converts nil", func(t *testing.T) { + assert.NoError(t, ToServiceTransientError(nil)) + }) + + t.Run("it keeps transient errors", func(t *testing.T) { + err := &types.InternalServiceError{} + assert.Equal(t, err, ToServiceTransientError(err)) + assert.True(t, IsServiceTransientError(ToServiceTransientError(err))) + }) + + t.Run("it converts errors to transient errors", func(t *testing.T) { + err := fmt.Errorf("error") + assert.True(t, IsServiceTransientError(ToServiceTransientError(err))) + }) +}