Skip to content

Commit

Permalink
Managed failover workflow improvements (cadence-workflow#4491)
Browse files Browse the repository at this point in the history
* Managed failover workflow improvements
  • Loading branch information
yux0 authored Sep 23, 2021
1 parent 607893d commit 2679a9c
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 22 deletions.
67 changes: 61 additions & 6 deletions service/worker/failovermanager/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"go.uber.org/cadence"
"go.uber.org/cadence/activity"
"go.uber.org/cadence/workflow"
"go.uber.org/zap"

"github.com/uber/cadence/client/frontend"
"github.com/uber/cadence/common"
Expand Down Expand Up @@ -102,6 +104,8 @@ type (
Domains []string
// DrillWaitTime defines the wait time of a failover drill
DrillWaitTime time.Duration
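// GracefulFailoverTimeoutInSeconds is the optional graceful failover timeout in seconds; when nil, no failover timeout is set on the domain update request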
GracefulFailoverTimeoutInSeconds *int32
}

// FailoverResult is workflow result
Expand All @@ -121,8 +125,9 @@ type (

// FailoverActivityParams params for activity
FailoverActivityParams struct {
Domains []string
TargetCluster string
Domains []string
TargetCluster string
GracefulFailoverTimeoutInSeconds *int32
}

// FailoverActivityResult result for failover activity
Expand Down Expand Up @@ -254,8 +259,9 @@ func failoverDomainsByBatch(
pauseSignalHandler()

failoverActivityParams := &FailoverActivityParams{
Domains: domains[i*batchSize : common.MinInt((i+1)*batchSize, totalNumOfDomains)],
TargetCluster: targetCluster,
Domains: domains[i*batchSize : common.MinInt((i+1)*batchSize, totalNumOfDomains)],
TargetCluster: targetCluster,
GracefulFailoverTimeoutInSeconds: params.GracefulFailoverTimeoutInSeconds,
}
var actResult FailoverActivityResult
err := workflow.ExecuteActivity(ao, FailoverActivity, failoverActivityParams).Get(ctx, &actResult)
Expand Down Expand Up @@ -390,6 +396,12 @@ func getClient(ctx context.Context) frontend.Client {
return feClient
}

// getRemoteClient returns the frontend client for the cluster named
// clusterName, looked up through the client bean of the FailoverManager
// stored in ctx under failoverManagerContextKey.
//
// The unchecked type assertion panics if ctx does not carry a *FailoverManager.
func getRemoteClient(ctx context.Context, clusterName string) frontend.Client {
	failoverManager := ctx.Value(failoverManagerContextKey).(*FailoverManager)
	return failoverManager.clientBean.GetRemoteFrontendClient(clusterName)
}

func getAllDomains(ctx context.Context, targetDomains []string) ([]*types.DescribeDomainResponse, error) {
feClient := getClient(ctx)
var res []*types.DescribeDomainResponse
Expand Down Expand Up @@ -432,16 +444,28 @@ func getAllDomains(ctx context.Context, targetDomains []string) ([]*types.Descri

// FailoverActivity activity def
func FailoverActivity(ctx context.Context, params *FailoverActivityParams) (*FailoverActivityResult, error) {
feClient := getClient(ctx)

logger := activity.GetLogger(ctx)
frontendClient := getClient(ctx)
domains := params.Domains
var successDomains []string
var failedDomains []string
for _, domain := range domains {
// Check that the target cluster has pollers for this domain's task lists
if err := validateTaskListPollerInfo(ctx, params.TargetCluster, domain); err != nil {
logger.Error("Failed to validate task list poller info", zap.Error(err))
failedDomains = append(failedDomains, domain)
continue
}
updateRequest := &types.UpdateDomainRequest{
Name: domain,
ActiveClusterName: common.StringPtr(params.TargetCluster),
}
_, err := feClient.UpdateDomain(ctx, updateRequest)
if params.GracefulFailoverTimeoutInSeconds != nil {
updateRequest.FailoverTimeoutInSeconds = params.GracefulFailoverTimeoutInSeconds
}

_, err := frontendClient.UpdateDomain(ctx, updateRequest)
if err != nil {
failedDomains = append(failedDomains, domain)
} else {
Expand All @@ -461,3 +485,34 @@ func cleanupChannel(channel workflow.Channel) {
}
}
}

// validateTaskListPollerInfo checks that targetCluster is able to pick up the
// work of domain before a failover is attempted.
//
// It fetches the domain's task lists from the client returned by getClient
// (treated as the local side) and from the remote frontend client of
// targetCluster. For every decision or activity task list that has at least
// one poller locally, the same task list must exist remotely and report at
// least one poller. Task lists without local pollers are ignored.
//
// It returns nil when all such task lists are covered. Otherwise it returns an
// error naming the first uncovered task list found (map iteration order, so
// not deterministic), or an error wrapping the failure of either lookup.
func validateTaskListPollerInfo(ctx context.Context, targetCluster string, domain string) error {
	remoteFrontendClient := getRemoteClient(ctx, targetCluster)
	frontendClient := getClient(ctx)

	localTaskListResponse, err := frontendClient.GetTaskListsByDomain(ctx, &types.GetTaskListsByDomainRequest{Domain: domain})
	if err != nil {
		// Wrap the cause so callers and logs can see why the lookup failed.
		return fmt.Errorf("failed to get task list for domain %s: %w", domain, err)
	}

	remoteTaskListResponse, err := remoteFrontendClient.GetTaskListsByDomain(ctx, &types.GetTaskListsByDomainRequest{Domain: domain})
	if err != nil {
		return fmt.Errorf("failed to get task list for domain %s from cluster %s: %w", domain, targetCluster, err)
	}

	remoteDecisionTaskLists := remoteTaskListResponse.GetDecisionTaskListMap()
	for name, tl := range localTaskListResponse.GetDecisionTaskListMap() {
		if len(tl.GetPollers()) == 0 {
			// Nothing polls this task list locally, so nothing is lost on failover.
			continue
		}
		remoteTaskList, ok := remoteDecisionTaskLists[name]
		if !ok || len(remoteTaskList.GetPollers()) == 0 {
			return fmt.Errorf("received zero poller in decision task list %s with domain %s", name, domain)
		}
	}

	remoteActivityTaskLists := remoteTaskListResponse.GetActivityTaskListMap()
	for name, tl := range localTaskListResponse.GetActivityTaskListMap() {
		if len(tl.GetPollers()) == 0 {
			continue
		}
		remoteTaskList, ok := remoteActivityTaskLists[name]
		if !ok || len(remoteTaskList.GetPollers()) == 0 {
			return fmt.Errorf("received zero poller in activity task list %s with domain %s", name, domain)
		}
	}
	return nil
}
127 changes: 126 additions & 1 deletion service/worker/failovermanager/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,13 +414,30 @@ func (s *failoverWorkflowTestSuite) TestGetDomainsActivity_WithTargetDomains() {
s.Equal([]string{"d1"}, result) // d3 filtered out because not managed
}

func (s *failoverWorkflowTestSuite) TestFailoverActivity() {
func (s *failoverWorkflowTestSuite) TestFailoverActivity_ForceFailover_Success() {
env, mockResource, controller := s.prepareTestActivityEnv()
defer controller.Finish()
defer mockResource.Finish(s.T())

domains := []string{"d1", "d2"}
describeTaskListResp := &types.DescribeTaskListResponse{Pollers: []*types.PollerInfo{
{
Identity: "test",
},
}}
taskListMap := map[string]*types.DescribeTaskListResponse{
"tl": describeTaskListResp,
}

mockResource.FrontendClient.EXPECT().UpdateDomain(gomock.Any(), gomock.Any()).Return(nil, nil).Times(len(domains))
mockResource.FrontendClient.EXPECT().GetTaskListsByDomain(gomock.Any(), gomock.Any()).Return(&types.GetTaskListsByDomainResponse{
DecisionTaskListMap: taskListMap,
ActivityTaskListMap: taskListMap,
}, nil).Times(len(domains))
mockResource.RemoteFrontendClient.EXPECT().GetTaskListsByDomain(gomock.Any(), gomock.Any()).Return(&types.GetTaskListsByDomainResponse{
DecisionTaskListMap: taskListMap,
ActivityTaskListMap: taskListMap,
}, nil).Times(len(domains))

params := &FailoverActivityParams{
Domains: domains,
Expand All @@ -434,6 +451,54 @@ func (s *failoverWorkflowTestSuite) TestFailoverActivity() {
s.Equal(domains, result.SuccessDomains)
}

// TestFailoverActivity_GracefulFailover_Success runs the failover activity
// with a graceful failover timeout set and expects every domain to succeed.
// Both the local and remote frontend mocks report a poller on task list "tl",
// and each UpdateDomain call must carry the timeout from the activity params.
func (s *failoverWorkflowTestSuite) TestFailoverActivity_GracefulFailover_Success() {
	env, mockResource, controller := s.prepareTestActivityEnv()
	defer controller.Finish()
	defer mockResource.Finish(s.T())

	domains := []string{"d1", "d2"}
	params := &FailoverActivityParams{
		Domains:                          domains,
		TargetCluster:                    "c2",
		GracefulFailoverTimeoutInSeconds: common.Int32Ptr(int32(10)),
	}

	// One task list with a single poller, used for both decision and activity maps.
	taskListMap := map[string]*types.DescribeTaskListResponse{
		"tl": {
			Pollers: []*types.PollerInfo{
				{
					Identity: "test",
				},
			},
		},
	}

	// Poller validation queries the local and the remote cluster once per domain.
	mockResource.FrontendClient.EXPECT().
		GetTaskListsByDomain(gomock.Any(), gomock.Any()).
		Return(&types.GetTaskListsByDomainResponse{
			DecisionTaskListMap: taskListMap,
			ActivityTaskListMap: taskListMap,
		}, nil).
		Times(len(domains))
	mockResource.RemoteFrontendClient.EXPECT().
		GetTaskListsByDomain(gomock.Any(), gomock.Any()).
		Return(&types.GetTaskListsByDomainResponse{
			DecisionTaskListMap: taskListMap,
			ActivityTaskListMap: taskListMap,
		}, nil).
		Times(len(domains))

	// Each domain is updated exactly once with the graceful failover timeout attached.
	for _, domain := range domains {
		expectedRequest := &types.UpdateDomainRequest{
			Name:                     domain,
			ActiveClusterName:        common.StringPtr("c2"),
			FailoverTimeoutInSeconds: params.GracefulFailoverTimeoutInSeconds,
		}
		mockResource.FrontendClient.EXPECT().UpdateDomain(gomock.Any(), expectedRequest).Return(nil, nil).Times(1)
	}

	actResult, err := env.ExecuteActivity(failoverActivityName, params)
	s.NoError(err)

	var result FailoverActivityResult
	s.NoError(actResult.Get(&result))
	s.Equal(domains, result.SuccessDomains)
}

func (s *failoverWorkflowTestSuite) TestFailoverActivity_Error() {
env, mockResource, controller := s.prepareTestActivityEnv()
defer controller.Finish()
Expand All @@ -449,8 +514,25 @@ func (s *failoverWorkflowTestSuite) TestFailoverActivity_Error() {
Name: "d2",
ActiveClusterName: common.StringPtr(targetCluster),
}
describeTaskListResp := &types.DescribeTaskListResponse{Pollers: []*types.PollerInfo{
{
Identity: "test",
},
}}
taskListMap := map[string]*types.DescribeTaskListResponse{
"tl": describeTaskListResp,
}

mockResource.FrontendClient.EXPECT().UpdateDomain(gomock.Any(), updateRequest1).Return(nil, nil)
mockResource.FrontendClient.EXPECT().UpdateDomain(gomock.Any(), updateRequest2).Return(nil, errors.New("mockErr"))
mockResource.FrontendClient.EXPECT().GetTaskListsByDomain(gomock.Any(), gomock.Any()).Return(&types.GetTaskListsByDomainResponse{
DecisionTaskListMap: taskListMap,
ActivityTaskListMap: taskListMap,
}, nil).Times(len(domains))
mockResource.RemoteFrontendClient.EXPECT().GetTaskListsByDomain(gomock.Any(), gomock.Any()).Return(&types.GetTaskListsByDomainResponse{
DecisionTaskListMap: taskListMap,
ActivityTaskListMap: taskListMap,
}, nil).Times(len(domains))

params := &FailoverActivityParams{
Domains: domains,
Expand All @@ -465,6 +547,49 @@ func (s *failoverWorkflowTestSuite) TestFailoverActivity_Error() {
s.Equal([]string{"d2"}, result.FailedDomains)
}

// TestFailoverActivity_NoPoller_Error covers the case where task list "tl" has
// a poller on the local side but none on the remote side. The activity itself
// must not return an error; instead every domain is reported as failed and
// UpdateDomain is never called.
func (s *failoverWorkflowTestSuite) TestFailoverActivity_NoPoller_Error() {
	env, mockResource, controller := s.prepareTestActivityEnv()
	defer controller.Finish()
	defer mockResource.Finish(s.T())

	domains := []string{"d1", "d2"}
	params := &FailoverActivityParams{
		Domains:       domains,
		TargetCluster: "c2",
	}

	// Local side: "tl" has one poller.
	localTaskLists := map[string]*types.DescribeTaskListResponse{
		"tl": {
			Pollers: []*types.PollerInfo{
				{
					Identity: "test",
				},
			},
		},
	}
	// Remote side: "tl" exists but its poller list is empty.
	remoteTaskLists := map[string]*types.DescribeTaskListResponse{
		"tl": {Pollers: []*types.PollerInfo{}},
	}

	mockResource.FrontendClient.EXPECT().
		GetTaskListsByDomain(gomock.Any(), gomock.Any()).
		Return(&types.GetTaskListsByDomainResponse{
			DecisionTaskListMap: localTaskLists,
			ActivityTaskListMap: localTaskLists,
		}, nil).
		Times(len(domains))
	mockResource.RemoteFrontendClient.EXPECT().
		GetTaskListsByDomain(gomock.Any(), gomock.Any()).
		Return(&types.GetTaskListsByDomainResponse{
			DecisionTaskListMap: remoteTaskLists,
			ActivityTaskListMap: remoteTaskLists,
		}, nil).
		Times(len(domains))
	// Validation fails for every domain, so no domain update may be issued.
	mockResource.FrontendClient.EXPECT().UpdateDomain(gomock.Any(), gomock.Any()).Times(0)

	actResult, err := env.ExecuteActivity(failoverActivityName, params)
	s.NoError(err)

	var result FailoverActivityResult
	s.NoError(actResult.Get(&result))
	s.Equal(0, len(result.SuccessDomains))
	s.Equal([]string{"d1", "d2"}, result.FailedDomains)
}

func (s *failoverWorkflowTestSuite) TestGetOperator() {
operator := "testOperator"
s.workflowEnv.SetMemoOnStart(map[string]interface{}{
Expand Down
12 changes: 10 additions & 2 deletions tools/cli/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -1074,8 +1074,12 @@ func newAdminFailoverCommands() []cli.Command {
},
cli.IntFlag{
Name: FlagFailoverTimeoutWithAlias,
Usage: "Optional graceful failover timeout in seconds. If this field is define, the failover will use graceful failover.",
},
cli.IntFlag{
Name: FlagExecutionTimeoutWithAlias,
Usage: "Optional Failover workflow timeout in seconds",
Value: defaultFailoverTimeoutInSeconds,
Value: defaultFailoverWorkflowTimeoutInSeconds,
},
cli.IntFlag{
Name: FlagFailoverWaitTimeWithAlias,
Expand Down Expand Up @@ -1200,8 +1204,12 @@ func newAdminFailoverCommands() []cli.Command {
},
cli.IntFlag{
Name: FlagFailoverTimeoutWithAlias,
Usage: "Optional graceful failover timeout in seconds. If this field is define, the failover will use graceful failover.",
},
cli.IntFlag{
Name: FlagExecutionTimeoutWithAlias,
Usage: "Optional Failover workflow timeout in seconds",
Value: defaultFailoverTimeoutInSeconds,
Value: defaultFailoverWorkflowTimeoutInSeconds,
},
cli.IntFlag{
Name: FlagFailoverWaitTimeWithAlias,
Expand Down
Loading

0 comments on commit 2679a9c

Please sign in to comment.