diff --git a/common/persistence/data_manager_interfaces.go b/common/persistence/data_manager_interfaces.go index f1763d1d9c5..53d1f3497ad 100644 --- a/common/persistence/data_manager_interfaces.go +++ b/common/persistence/data_manager_interfaces.go @@ -2054,3 +2054,14 @@ func (e *WorkflowExecutionInfo) CopyPartitionConfig() map[string]string { } return partitionConfig } + +func (p *TaskListPartitionConfig) ToInternalType() *types.TaskListPartitionConfig { + if p == nil { + return nil + } + return &types.TaskListPartitionConfig{ + Version: p.Version, + NumReadPartitions: int32(p.NumReadPartitions), + NumWritePartitions: int32(p.NumWritePartitions), + } +} diff --git a/common/persistence/data_manager_interfaces_test.go b/common/persistence/data_manager_interfaces_test.go index 67b1176afe0..166fc090a58 100644 --- a/common/persistence/data_manager_interfaces_test.go +++ b/common/persistence/data_manager_interfaces_test.go @@ -509,3 +509,41 @@ func TestCopyPartitionConfig(t *testing.T) { }) } } + +func TestTaskListPartitionConfigToInternalType(t *testing.T) { + testCases := []struct { + name string + input *TaskListPartitionConfig + expect *types.TaskListPartitionConfig + }{ + { + name: "nil case", + input: nil, + expect: nil, + }, + { + name: "empty case", + input: &TaskListPartitionConfig{}, + expect: &types.TaskListPartitionConfig{}, + }, + { + name: "normal case", + input: &TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 3, + }, + expect: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 3, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expect, tc.input.ToInternalType()) + }) + } +} diff --git a/service/history/task/transfer_active_task_executor_test.go b/service/history/task/transfer_active_task_executor_test.go index fb46dd1cea1..ec480e94260 100644 --- a/service/history/task/transfer_active_task_executor_test.go +++ b/service/history/task/transfer_active_task_executor_test.go @@ -596,8 +596,8 @@ func (s *transferActiveTaskExecutorSuite) TestProcessDecisionTask_StickyWorkerUn } gomock.InOrder( - s.mockMatchingClient.EXPECT().AddDecisionTask(gomock.Any(), addDecisionTaskRequest).Return(&types.StickyWorkerUnavailableError{}).Times(1), - s.mockMatchingClient.EXPECT().AddDecisionTask(gomock.Any(), gomock.Eq(&modifiedRequest)).Return(nil).Times(1), + s.mockMatchingClient.EXPECT().AddDecisionTask(gomock.Any(), addDecisionTaskRequest).Return(nil, &types.StickyWorkerUnavailableError{}).Times(1), + s.mockMatchingClient.EXPECT().AddDecisionTask(gomock.Any(), gomock.Eq(&modifiedRequest)).Return(&types.AddDecisionTaskResponse{}, nil).Times(1), ) err = s.transferActiveTaskExecutor.Execute(transferTask, true) diff --git a/service/matching/handler/engine.go b/service/matching/handler/engine.go index 6c1b9362f99..e39e23c9e55 100644 --- a/service/matching/handler/engine.go +++ b/service/matching/handler/engine.go @@ -321,7 +321,8 @@ func (e *matchingEngineImpl) removeTaskListManager(tlMgr tasklist.Manager) { func (e *matchingEngineImpl) AddDecisionTask( hCtx *handlerContext, request *types.AddDecisionTaskRequest, -) (bool, error) { +) (*types.AddDecisionTaskResponse, error) { + startT := time.Now() domainID := request.GetDomainUUID() taskListName := request.GetTaskList().GetName() taskListKind := request.GetTaskList().Kind @@ -355,13 +356,13 @@ func (e *matchingEngineImpl) AddDecisionTask( taskListID, err := tasklist.NewIdentifier(domainID, taskListName, taskListType) if err != nil { - return false, err + return nil, err } // get the domainName domainName, err := e.domainCache.GetDomainName(domainID) if err != nil { - return false, err + return nil, err } // Only emit traffic metrics if the tasklist is not sticky and is not forwarded @@ -376,13 +377,13 @@ func (e *matchingEngineImpl) AddDecisionTask( tlMgr, err := e.getTaskListManager(taskListID, taskListKind) if err != nil { - return false, err + return nil, err } if taskListKind != nil && *taskListKind == types.TaskListKindSticky { // check if the sticky worker is still available, if not, fail this request early if !tlMgr.HasPollerAfter(e.timeSource.Now().Add(-_stickyPollerUnavailableWindow)) { - return false, _stickyPollerUnavailableError + return nil, _stickyPollerUnavailableError } } @@ -396,18 +397,28 @@ func (e *matchingEngineImpl) AddDecisionTask( PartitionConfig: request.GetPartitionConfig(), } - return tlMgr.AddTask(hCtx.Context, tasklist.AddTaskParams{ + syncMatched, err := tlMgr.AddTask(hCtx.Context, tasklist.AddTaskParams{ TaskInfo: taskInfo, Source: request.GetSource(), ForwardedFrom: request.GetForwardedFrom(), }) + if err != nil { + return nil, err + } + if syncMatched { + hCtx.scope.RecordTimer(metrics.SyncMatchLatencyPerTaskList, time.Since(startT)) + } + return &types.AddDecisionTaskResponse{ + PartitionConfig: tlMgr.TaskListPartitionConfig(), + }, nil } // AddActivityTask either delivers task directly to waiting poller or save it into task list persistence. func (e *matchingEngineImpl) AddActivityTask( hCtx *handlerContext, request *types.AddActivityTaskRequest, -) (bool, error) { +) (*types.AddActivityTaskResponse, error) { + startT := time.Now() domainID := request.GetDomainUUID() taskListName := request.GetTaskList().GetName() taskListKind := request.GetTaskList().Kind @@ -427,13 +438,13 @@ func (e *matchingEngineImpl) AddActivityTask( taskListID, err := tasklist.NewIdentifier(domainID, taskListName, taskListType) if err != nil { - return false, err + return nil, err } // get the domainName domainName, err := e.domainCache.GetDomainName(domainID) if err != nil { - return false, err + return nil, err } // Only emit traffic metrics if the tasklist is not sticky and is not forwarded @@ -448,7 +459,7 @@ func (e *matchingEngineImpl) AddActivityTask( tlMgr, err := e.getTaskListManager(taskListID, taskListKind) if err != nil { - return false, err + return nil, err } taskInfo := &persistence.TaskInfo{ @@ -461,12 +472,21 @@ func (e *matchingEngineImpl) AddActivityTask( PartitionConfig: request.GetPartitionConfig(), } - return tlMgr.AddTask(hCtx.Context, tasklist.AddTaskParams{ + syncMatched, err := tlMgr.AddTask(hCtx.Context, tasklist.AddTaskParams{ TaskInfo: taskInfo, Source: request.GetSource(), ForwardedFrom: request.GetForwardedFrom(), ActivityTaskDispatchInfo: request.ActivityTaskDispatchInfo, }) + if err != nil { + return nil, err + } + if syncMatched { + hCtx.scope.RecordTimer(metrics.SyncMatchLatencyPerTaskList, time.Since(startT)) + } + return &types.AddActivityTaskResponse{ + PartitionConfig: tlMgr.TaskListPartitionConfig(), + }, nil } // PollForDecisionTask tries to get the decision task using exponential backoff. @@ -512,7 +532,11 @@ pollLoop: pollerCtx := tasklist.ContextWithPollerID(hCtx.Context, pollerID) pollerCtx = tasklist.ContextWithIdentity(pollerCtx, request.GetIdentity()) pollerCtx = tasklist.ContextWithIsolationGroup(pollerCtx, req.GetIsolationGroup()) - task, err := e.getTask(pollerCtx, taskListID, nil, taskListKind) + tlMgr, err := e.getTaskListManager(taskListID, taskListKind) + if err != nil { + return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err) + } + task, err := tlMgr.GetTask(pollerCtx, nil) if err != nil { // TODO: Is empty poll the best reply for errPumpClosed? if errors.Is(err, tasklist.ErrNoTasks) || errors.Is(err, errPumpClosed) { @@ -534,7 +558,9 @@ pollLoop: "RequestForwardedFrom": req.GetForwardedFrom(), }, }) - return emptyPollForDecisionTaskResponse, nil + return &types.MatchingPollForDecisionTaskResponse{ + PartitionConfig: tlMgr.TaskListPartitionConfig(), + }, nil } return nil, fmt.Errorf("couldn't get task: %w", err) } @@ -568,7 +594,9 @@ pollLoop: if err != nil { // will notify query client that the query task failed e.deliverQueryResult(task.Query.TaskID, &queryResult{internalError: err}) //nolint:errcheck - return emptyPollForDecisionTaskResponse, nil + return &types.MatchingPollForDecisionTaskResponse{ + PartitionConfig: tlMgr.TaskListPartitionConfig(), + }, nil } isStickyEnabled := false @@ -585,7 +613,7 @@ pollLoop: BranchToken: mutableStateResp.CurrentBranchToken, HistorySize: mutableStateResp.HistorySize, } - return e.createPollForDecisionTaskResponse(task, resp, hCtx.scope), nil + return e.createPollForDecisionTaskResponse(task, resp, hCtx.scope, tlMgr.TaskListPartitionConfig()), nil } e.emitTaskIsolationMetrics(hCtx.scope, task.Event.PartitionConfig, req.GetIsolationGroup()) @@ -641,7 +669,7 @@ pollLoop: }, }) - return e.createPollForDecisionTaskResponse(task, resp, hCtx.scope), nil + return e.createPollForDecisionTaskResponse(task, resp, hCtx.scope, tlMgr.TaskListPartitionConfig()), nil } } @@ -683,11 +711,17 @@ pollLoop: pollerCtx = tasklist.ContextWithIdentity(pollerCtx, request.GetIdentity()) pollerCtx = tasklist.ContextWithIsolationGroup(pollerCtx, req.GetIsolationGroup()) taskListKind := request.TaskList.Kind - task, err := e.getTask(pollerCtx, taskListID, maxDispatch, taskListKind) + tlMgr, err := e.getTaskListManager(taskListID, taskListKind) + if err != nil { + return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err) + } + task, err := tlMgr.GetTask(pollerCtx, maxDispatch) if err != nil { // TODO: Is empty poll the best reply for errPumpClosed? if errors.Is(err, tasklist.ErrNoTasks) || errors.Is(err, errPumpClosed) { - return emptyPollForActivityTaskResponse, nil + return &types.MatchingPollForActivityTaskResponse{ + PartitionConfig: tlMgr.TaskListPartitionConfig(), + }, nil } e.logger.Error("Received unexpected err while getting task", tag.WorkflowTaskListName(taskListName), @@ -705,7 +739,7 @@ pollLoop: e.emitTaskIsolationMetrics(hCtx.scope, task.Event.PartitionConfig, req.GetIsolationGroup()) if task.ActivityTaskDispatchInfo != nil { task.Finish(nil) - return e.createSyncMatchPollForActivityTaskResponse(task, task.ActivityTaskDispatchInfo), nil + return e.createSyncMatchPollForActivityTaskResponse(task, task.ActivityTaskDispatchInfo, tlMgr.TaskListPartitionConfig()), nil } resp, err := e.recordActivityTaskStarted(hCtx.Context, request, task) @@ -737,13 +771,14 @@ pollLoop: continue pollLoop } task.Finish(nil) - return e.createPollForActivityTaskResponse(task, resp, hCtx.scope), nil + return e.createPollForActivityTaskResponse(task, resp, hCtx.scope, tlMgr.TaskListPartitionConfig()), nil } } func (e *matchingEngineImpl) createSyncMatchPollForActivityTaskResponse( task *tasklist.InternalTask, activityTaskDispatchInfo *types.ActivityTaskDispatchInfo, + partitionConfig *types.TaskListPartitionConfig, ) *types.MatchingPollForActivityTaskResponse { scheduledEvent := activityTaskDispatchInfo.ScheduledEvent @@ -777,6 +812,7 @@ func (e *matchingEngineImpl) createSyncMatchPollForActivityTaskResponse( response.HeartbeatDetails = activityTaskDispatchInfo.HeartbeatDetails response.WorkflowType = activityTaskDispatchInfo.WorkflowType response.WorkflowDomain = activityTaskDispatchInfo.WorkflowDomain + response.PartitionConfig = partitionConfig return response } @@ -1006,15 +1042,6 @@ func (e *matchingEngineImpl) getAllPartitions( return partitionKeys, nil } -// Loads a task from persistence and wraps it in a task context -func (e *matchingEngineImpl) getTask(ctx context.Context, taskList *tasklist.Identifier, maxDispatchPerSecond *float64, taskListKind *types.TaskListKind) (*tasklist.InternalTask, error) { - tlMgr, err := e.getTaskListManager(taskList, taskListKind) - if err != nil { - return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err) - } - return tlMgr.GetTask(ctx, maxDispatchPerSecond) -} - func (e *matchingEngineImpl) unloadTaskList(tlMgr tasklist.Manager) { id := tlMgr.TaskListID() e.taskListsLock.Lock() @@ -1033,6 +1060,7 @@ func (e *matchingEngineImpl) createPollForDecisionTaskResponse( task *tasklist.InternalTask, historyResponse *types.RecordDecisionTaskStartedResponse, scope metrics.Scope, + partitionConfig *types.TaskListPartitionConfig, ) *types.MatchingPollForDecisionTaskResponse { var token []byte @@ -1067,6 +1095,7 @@ func (e *matchingEngineImpl) createPollForDecisionTaskResponse( response.Query = task.Query.Request.QueryRequest.Query } response.BacklogCountHint = task.BacklogCountHint + response.PartitionConfig = partitionConfig return response } @@ -1075,6 +1104,7 @@ func (e *matchingEngineImpl) createPollForActivityTaskResponse( task *tasklist.InternalTask, historyResponse *types.RecordActivityTaskStartedResponse, scope metrics.Scope, + partitionConfig *types.TaskListPartitionConfig, ) *types.MatchingPollForActivityTaskResponse { scheduledEvent := historyResponse.ScheduledEvent @@ -1118,6 +1148,7 @@ func (e *matchingEngineImpl) createPollForActivityTaskResponse( response.HeartbeatDetails = historyResponse.HeartbeatDetails response.WorkflowType = historyResponse.WorkflowType response.WorkflowDomain = historyResponse.WorkflowDomain + response.PartitionConfig = partitionConfig return response } diff --git a/service/matching/handler/engine_integration_test.go b/service/matching/handler/engine_integration_test.go index da79e12291e..bd0d82aaff4 100644 --- a/service/matching/handler/engine_integration_test.go +++ b/service/matching/handler/engine_integration_test.go @@ -1067,12 +1067,14 @@ func (s *matchingEngineSuite) TestAddTaskAfterStartFailure() { s.NoError(err) s.EqualValues(1, s.taskManager.GetTaskCount(tlID)) - ctx, err := s.matchingEngine.getTask(context.Background(), tlID, nil, &tlKind) + tlMgr, err := s.matchingEngine.getTaskListManager(tlID, &tlKind) + s.NoError(err) + ctx, err := tlMgr.GetTask(context.Background(), nil) s.NoError(err) ctx.Finish(errors.New("test error")) s.EqualValues(1, s.taskManager.GetTaskCount(tlID)) - ctx2, err := s.matchingEngine.getTask(context.Background(), tlID, nil, &tlKind) + ctx2, err := tlMgr.GetTask(context.Background(), nil) s.NoError(err) s.NotEqual(ctx.Event.TaskID, ctx2.Event.TaskID) @@ -1514,9 +1516,13 @@ type addTaskRequest struct { PartitionConfig map[string]string } -func addTask(engine *matchingEngineImpl, hCtx *handlerContext, request *addTaskRequest) (bool, error) { +type addTaskResponse struct { + PartitionConfig *types.TaskListPartitionConfig +} + +func addTask(engine *matchingEngineImpl, hCtx *handlerContext, request *addTaskRequest) (*addTaskResponse, error) { if request.TaskType == persistence.TaskListTypeActivity { - return engine.AddActivityTask(hCtx, &types.AddActivityTaskRequest{ + resp, err := engine.AddActivityTask(hCtx, &types.AddActivityTaskRequest{ SourceDomainUUID: request.DomainUUID, DomainUUID: request.DomainUUID, Execution: request.Execution, @@ -1527,8 +1533,14 @@ func addTask(engine *matchingEngineImpl, hCtx *handlerContext, request *addTaskR ForwardedFrom: request.ForwardedFrom, PartitionConfig: request.PartitionConfig, }) + if err != nil { + return nil, err + } + return &addTaskResponse{ + PartitionConfig: resp.PartitionConfig, + }, nil } - return engine.AddDecisionTask(hCtx, &types.AddDecisionTaskRequest{ + resp, err := engine.AddDecisionTask(hCtx, &types.AddDecisionTaskRequest{ DomainUUID: request.DomainUUID, Execution: request.Execution, TaskList: request.TaskList, @@ -1538,6 +1550,12 @@ func addTask(engine *matchingEngineImpl, hCtx *handlerContext, request *addTaskR ForwardedFrom: request.ForwardedFrom, PartitionConfig: request.PartitionConfig, }) + if err != nil { + return nil, err + } + return &addTaskResponse{ + PartitionConfig: resp.PartitionConfig, + }, nil } type pollTaskRequest struct { diff --git a/service/matching/handler/handler.go b/service/matching/handler/handler.go index 557781de1e4..b0766e3a07b 100644 --- a/service/matching/handler/handler.go +++ b/service/matching/handler/handler.go @@ -23,7 +23,6 @@ package handler import ( "context" "sync" - "time" "github.com/uber/cadence/common" "github.com/uber/cadence/common/cache" @@ -128,7 +127,6 @@ func (h *handlerImpl) AddActivityTask( ) (resp *types.AddActivityTaskResponse, retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() - startT := time.Now() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, @@ -148,12 +146,8 @@ func (h *handlerImpl) AddActivityTask( return nil, hCtx.handleErr(errMatchingHostThrottle) } - syncMatch, err := h.engine.AddActivityTask(hCtx, request) - if syncMatch { - hCtx.scope.RecordTimer(metrics.SyncMatchLatencyPerTaskList, time.Since(startT)) - } - - return &types.AddActivityTaskResponse{}, hCtx.handleErr(err) + resp, err := h.engine.AddActivityTask(hCtx, request) + return resp, hCtx.handleErr(err) } // AddDecisionTask - adds a decision task. @@ -163,7 +157,6 @@ func (h *handlerImpl) AddDecisionTask( ) (resp *types.AddDecisionTaskResponse, retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() - startT := time.Now() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, @@ -183,11 +176,8 @@ func (h *handlerImpl) AddDecisionTask( return nil, hCtx.handleErr(errMatchingHostThrottle) } - syncMatch, err := h.engine.AddDecisionTask(hCtx, request) - if syncMatch { - hCtx.scope.RecordTimer(metrics.SyncMatchLatencyPerTaskList, time.Since(startT)) - } - return &types.AddDecisionTaskResponse{}, hCtx.handleErr(err) + resp, err := h.engine.AddDecisionTask(hCtx, request) + return resp, hCtx.handleErr(err) } // PollForActivityTask - long poll for an activity task. diff --git a/service/matching/handler/handler_test.go b/service/matching/handler/handler_test.go index 9610a15daef..d18fb2d9604 100644 --- a/service/matching/handler/handler_test.go +++ b/service/matching/handler/handler_test.go @@ -164,13 +164,27 @@ func (s *handlerSuite) TestAddActivityTask() { testCases := []struct { name string setupMocks func() + want *types.AddActivityTaskResponse err error }{ { name: "Success case", setupMocks: func() { s.mockLimiter.EXPECT().Allow().Return(true).Times(1) - s.mockEngine.EXPECT().AddActivityTask(gomock.Any(), &request).Return(true, nil).Times(1) + s.mockEngine.EXPECT().AddActivityTask(gomock.Any(), &request).Return(&types.AddActivityTaskResponse{ + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, nil).Times(1) + }, + want: &types.AddActivityTaskResponse{ + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, }, }, { @@ -184,7 +198,7 @@ func (s *handlerSuite) TestAddActivityTask() { name: "Error case - AddActivityTask failed", setupMocks: func() { s.mockLimiter.EXPECT().Allow().Return(true).Times(1) // Ensure Allow() returns true - s.mockEngine.EXPECT().AddActivityTask(gomock.Any(), &request).Return(false, errors.New("add-activity-error")).Times(1) + s.mockEngine.EXPECT().AddActivityTask(gomock.Any(), &request).Return(nil, errors.New("add-activity-error")).Times(1) }, err: &types.InternalServiceError{Message: "add-activity-error"}, }, @@ -195,12 +209,13 @@ func (s *handlerSuite) TestAddActivityTask() { tc.setupMocks() s.mockDomainCache.EXPECT().GetDomainName(request.DomainUUID).Return(s.testDomain, nil).Times(1) - _, err := s.handler.AddActivityTask(context.Background(), &request) + resp, err := s.handler.AddActivityTask(context.Background(), &request) if tc.err != nil { s.Error(err) s.Equal(tc.err, err) } else { + s.Equal(tc.want, resp) s.NoError(err) } }) @@ -217,13 +232,27 @@ func (s *handlerSuite) TestAddDecisionTask() { testCases := []struct { name string setupMocks func() + want *types.AddDecisionTaskResponse err error }{ { name: "Success case", setupMocks: func() { s.mockLimiter.EXPECT().Allow().Return(true).Times(1) - s.mockEngine.EXPECT().AddDecisionTask(gomock.Any(), &request).Return(true, nil).Times(1) + s.mockEngine.EXPECT().AddDecisionTask(gomock.Any(), &request).Return(&types.AddDecisionTaskResponse{ + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, + }, nil).Times(1) + }, + want: &types.AddDecisionTaskResponse{ + PartitionConfig: &types.TaskListPartitionConfig{ + Version: 1, + NumReadPartitions: 2, + NumWritePartitions: 2, + }, }, }, { @@ -237,7 +266,7 @@ func (s *handlerSuite) TestAddDecisionTask() { name: "Error case - AddDecisionTask failed", setupMocks: func() { s.mockLimiter.EXPECT().Allow().Return(true).Times(1) // Ensure Allow() returns true - s.mockEngine.EXPECT().AddDecisionTask(gomock.Any(), &request).Return(false, errors.New("add-decision-error")).Times(1) + s.mockEngine.EXPECT().AddDecisionTask(gomock.Any(), &request).Return(nil, errors.New("add-decision-error")).Times(1) }, err: &types.InternalServiceError{Message: "add-decision-error"}, }, @@ -248,12 +277,13 @@ func (s *handlerSuite) TestAddDecisionTask() { tc.setupMocks() s.mockDomainCache.EXPECT().GetDomainName(request.DomainUUID).Return(s.testDomain, nil).Times(1) - _, err := s.handler.AddDecisionTask(context.Background(), &request) + resp, err := s.handler.AddDecisionTask(context.Background(), &request) if tc.err != nil { s.Error(err) s.Equal(tc.err, err) } else { + s.Equal(tc.want, resp) s.NoError(err) } }) diff --git a/service/matching/handler/interfaces.go b/service/matching/handler/interfaces.go index 70c63aabaa2..696dda897d5 100644 --- a/service/matching/handler/interfaces.go +++ b/service/matching/handler/interfaces.go @@ -36,8 +36,8 @@ type ( Engine interface { common.Daemon - AddDecisionTask(hCtx *handlerContext, request *types.AddDecisionTaskRequest) (syncMatch bool, err error) - AddActivityTask(hCtx *handlerContext, request *types.AddActivityTaskRequest) (syncMatch bool, err error) + AddDecisionTask(hCtx *handlerContext, request *types.AddDecisionTaskRequest) (*types.AddDecisionTaskResponse, error) + AddActivityTask(hCtx *handlerContext, request *types.AddActivityTaskRequest) (*types.AddActivityTaskResponse, error) PollForDecisionTask(hCtx *handlerContext, request *types.MatchingPollForDecisionTaskRequest) (*types.MatchingPollForDecisionTaskResponse, error) PollForActivityTask(hCtx *handlerContext, request *types.MatchingPollForActivityTaskRequest) (*types.MatchingPollForActivityTaskResponse, error) QueryWorkflow(hCtx *handlerContext, request *types.MatchingQueryWorkflowRequest) (*types.QueryWorkflowResponse, error) diff --git a/service/matching/handler/interfaces_mock.go b/service/matching/handler/interfaces_mock.go index bdfdb30f13a..e749ccefc7e 100644 --- a/service/matching/handler/interfaces_mock.go +++ b/service/matching/handler/interfaces_mock.go @@ -59,10 +59,10 @@ func (m *MockEngine) EXPECT() *MockEngineMockRecorder { } // AddActivityTask mocks base method. -func (m *MockEngine) AddActivityTask(hCtx *handlerContext, request *types.AddActivityTaskRequest) (bool, error) { +func (m *MockEngine) AddActivityTask(hCtx *handlerContext, request *types.AddActivityTaskRequest) (*types.AddActivityTaskResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddActivityTask", hCtx, request) - ret0, _ := ret[0].(bool) + ret0, _ := ret[0].(*types.AddActivityTaskResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -74,10 +74,10 @@ func (mr *MockEngineMockRecorder) AddActivityTask(hCtx, request interface{}) *go } // AddDecisionTask mocks base method. -func (m *MockEngine) AddDecisionTask(hCtx *handlerContext, request *types.AddDecisionTaskRequest) (bool, error) { +func (m *MockEngine) AddDecisionTask(hCtx *handlerContext, request *types.AddDecisionTaskRequest) (*types.AddDecisionTaskResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddDecisionTask", hCtx, request) - ret0, _ := ret[0].(bool) + ret0, _ := ret[0].(*types.AddDecisionTaskResponse) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/service/matching/tasklist/db.go b/service/matching/tasklist/db.go index eef629eba9f..a29330e515a 100644 --- a/service/matching/tasklist/db.go +++ b/service/matching/tasklist/db.go @@ -32,16 +32,18 @@ import ( type ( taskListDB struct { - sync.Mutex - domainID string - domainName string - taskListName string - taskListKind int - taskType int - rangeID int64 - backlogCount int64 - store persistence.TaskManager - logger log.Logger + sync.RWMutex + domainID string + domainName string + taskListName string + taskListKind int + taskType int + rangeID int64 + backlogCount int64 + ackLevel int64 + partitionConfig *persistence.TaskListPartitionConfig + store persistence.TaskManager + logger log.Logger } taskListState struct { rangeID int64 @@ -73,8 +75,8 @@ func newTaskListDB(store persistence.TaskManager, domainID string, domainName st // RangeID returns the current persistence view of rangeID func (db *taskListDB) RangeID() int64 { - db.Lock() - defer db.Unlock() + db.RLock() + defer db.RUnlock() return db.rangeID } @@ -83,6 +85,12 @@ func (db *taskListDB) BacklogCount() int64 { return atomic.LoadInt64(&db.backlogCount) } +func (db *taskListDB) PartitionConfig() *persistence.TaskListPartitionConfig { + db.RLock() + defer db.RUnlock() + return db.partitionConfig +} + // RenewLease renews the lease on a tasklist. If there is no previous lease, // this method will attempt to steal tasklist from current owner func (db *taskListDB) RenewLease() (taskListState, error) { @@ -100,6 +108,8 @@ func (db *taskListDB) RenewLease() (taskListState, error) { return taskListState{}, err } db.rangeID = resp.TaskListInfo.RangeID + db.ackLevel = resp.TaskListInfo.AckLevel + db.partitionConfig = resp.TaskListInfo.AdaptivePartitionConfig return taskListState{rangeID: db.rangeID, ackLevel: resp.TaskListInfo.AckLevel}, nil } @@ -109,16 +119,21 @@ func (db *taskListDB) UpdateState(ackLevel int64) error { defer db.Unlock() _, err := db.store.UpdateTaskList(context.Background(), &persistence.UpdateTaskListRequest{ TaskListInfo: &persistence.TaskListInfo{ - DomainID: db.domainID, - Name: db.taskListName, - TaskType: db.taskType, - AckLevel: ackLevel, - RangeID: db.rangeID, - Kind: db.taskListKind, + DomainID: db.domainID, + Name: db.taskListName, + TaskType: db.taskType, + AckLevel: ackLevel, + RangeID: db.rangeID, + Kind: db.taskListKind, + AdaptivePartitionConfig: db.partitionConfig, }, DomainName: db.domainName, }) - return err + if err != nil { + return err + } + db.ackLevel = ackLevel + return nil } // CreateTasks creates a batch of given tasks for this task list @@ -150,28 +165,6 @@ func (db *taskListDB) GetTasks(minTaskID int64, maxTaskID int64, batchSize int) }) } -// CompleteTask deletes a single task from this task list -func (db *taskListDB) CompleteTask(taskID int64) error { - err := db.store.CompleteTask(context.Background(), &persistence.CompleteTaskRequest{ - TaskList: &persistence.TaskListInfo{ - DomainID: db.domainID, - Name: db.taskListName, - TaskType: db.taskType, - }, - TaskID: taskID, - DomainName: db.domainName, - }) - if err != nil { - db.logger.Error("Persistent store operation failure", - tag.StoreOperationCompleteTask, - tag.Error(err), - tag.TaskID(taskID), - tag.TaskType(db.taskType), - tag.WorkflowTaskListName(db.taskListName)) - } - return err -} - // CompleteTasksLessThan deletes of tasks less than the given taskID. Limit is // the upper bound of number of tasks that can be deleted by this method. It may // or may not be honored @@ -211,3 +204,16 @@ func (db *taskListDB) GetTaskListSize(ackLevel int64) (int64, error) { atomic.StoreInt64(&db.backlogCount, resp.Size) return resp.Size, nil } + +func (db *taskListDB) GetTaskListInfo(taskListName string) (*persistence.TaskListInfo, error) { + resp, err := db.store.GetTaskList(context.Background(), &persistence.GetTaskListRequest{ + DomainID: db.domainID, + DomainName: db.domainName, + TaskList: taskListName, + TaskType: db.taskType, + }) + if err != nil { + return nil, err + } + return resp.TaskListInfo, nil +} diff --git a/service/matching/tasklist/interfaces.go b/service/matching/tasklist/interfaces.go index e55c5195c8c..741117dfded 100644 --- a/service/matching/tasklist/interfaces.go +++ b/service/matching/tasklist/interfaces.go @@ -57,6 +57,7 @@ type ( String() string GetTaskListKind() types.TaskListKind TaskListID() *Identifier + TaskListPartitionConfig() *types.TaskListPartitionConfig } TaskMatcher interface { diff --git a/service/matching/tasklist/interfaces_mock.go b/service/matching/tasklist/interfaces_mock.go index 5a5986ab2c7..feb505e5c96 100644 --- a/service/matching/tasklist/interfaces_mock.go +++ b/service/matching/tasklist/interfaces_mock.go @@ -240,6 +240,20 @@ func (mr *MockManagerMockRecorder) TaskListID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TaskListID", reflect.TypeOf((*MockManager)(nil).TaskListID)) } +// TaskListPartitionConfig mocks base method. +func (m *MockManager) TaskListPartitionConfig() *types.TaskListPartitionConfig { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TaskListPartitionConfig") + ret0, _ := ret[0].(*types.TaskListPartitionConfig) + return ret0 +} + +// TaskListPartitionConfig indicates an expected call of TaskListPartitionConfig. +func (mr *MockManagerMockRecorder) TaskListPartitionConfig() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TaskListPartitionConfig", reflect.TypeOf((*MockManager)(nil).TaskListPartitionConfig)) +} + // MockTaskMatcher is a mock of TaskMatcher interface. type MockTaskMatcher struct { ctrl *gomock.Controller diff --git a/service/matching/tasklist/task_list_manager.go b/service/matching/tasklist/task_list_manager.go index 6d9ad5be311..c188a4f0748 100644 --- a/service/matching/tasklist/task_list_manager.go +++ b/service/matching/tasklist/task_list_manager.go @@ -121,6 +121,9 @@ type ( closeCallback func(Manager) qpsTracker stats.QPSTracker + + partitionConfigLock sync.RWMutex + partitionConfig *types.TaskListPartitionConfig } ) @@ -213,11 +216,22 @@ func NewManager( func (c *taskListManagerImpl) Start() error { defer c.startWG.Done() - c.liveness.Start() if err := c.taskWriter.Start(); err != nil { c.Stop() return err } + c.reloadTaskListPartitionConfig() + if c.taskListID.IsRoot() && c.taskListKind != types.TaskListKindSticky { + c.partitionConfig = c.db.PartitionConfig().ToInternalType() + if c.partitionConfig == nil { + c.partitionConfig = &types.TaskListPartitionConfig{ + Version: 0, + NumReadPartitions: 1, + NumWritePartitions: 1, + } + } + } + c.liveness.Start() c.taskReader.Start() c.qpsTracker.Start() @@ -253,6 +267,46 @@ func (c *taskListManagerImpl) handleErr(err error) error { return err } +func (c *taskListManagerImpl) reloadTaskListPartitionConfig() { + if c.taskListID.IsRoot() { + return + } + c.partitionConfigLock.RLock() + if c.partitionConfig != nil { + c.partitionConfigLock.RUnlock() + return + } + c.partitionConfigLock.RUnlock() + + c.partitionConfigLock.Lock() + if c.partitionConfig != nil { + c.partitionConfigLock.Unlock() + return + } + defer c.partitionConfigLock.Unlock() + info, err := c.db.GetTaskListInfo(c.taskListID.GetRoot()) + if err != nil { + // Given current set up, it's possible that the root partition is created after non-root partition + // In this case, we don't fail the start for now, but set the config to nil. + // We'll check if the field is nil, if it is, we'll reload it from database on demand. + return + } + c.partitionConfig = info.AdaptivePartitionConfig.ToInternalType() + if c.partitionConfig == nil { + c.partitionConfig = &types.TaskListPartitionConfig{ + Version: 0, + NumReadPartitions: 1, + NumWritePartitions: 1, + } + } +} + +func (c *taskListManagerImpl) TaskListPartitionConfig() *types.TaskListPartitionConfig { + c.partitionConfigLock.RLock() + defer c.partitionConfigLock.RUnlock() + return c.partitionConfig +} + // AddTask adds a task to the task list. This method will first attempt a synchronous // match with a poller. When there are no pollers or if rate limit is exceeded, task will // be written to database and later asynchronously matched with a poller @@ -263,6 +317,7 @@ func (c *taskListManagerImpl) AddTask(ctx context.Context, params AddTaskParams) c.Stop() return false, errShutdown } + c.reloadTaskListPartitionConfig() if params.ForwardedFrom == "" { // request sent by history service c.liveness.MarkAlive() @@ -349,6 +404,7 @@ func (c *taskListManagerImpl) DispatchQueryTask( request *types.MatchingQueryWorkflowRequest, ) (*types.QueryWorkflowResponse, error) { c.startWG.Wait() + c.reloadTaskListPartitionConfig() task := newInternalQueryTask(taskID, request) return c.matcher.OfferQuery(ctx, task) } @@ -366,6 +422,7 @@ func (c *taskListManagerImpl) GetTask( return nil, ErrNoTasks } c.liveness.MarkAlive() + c.reloadTaskListPartitionConfig() // TODO: consider return early if QPS and backlog count are both 0, // since there is no task to be returned task, err := c.getTask(ctx, maxDispatchPerSecond) diff --git a/service/matching/tasklist/task_list_manager_test.go b/service/matching/tasklist/task_list_manager_test.go index e3948208b65..0fef0bd4b0e 100644 --- a/service/matching/tasklist/task_list_manager_test.go +++ b/service/matching/tasklist/task_list_manager_test.go @@ -928,3 +928,56 @@ func TestTaskListManagerImpl_HasPollerAfter(t *testing.T) { func getIsolationgroupsHelper() []string { return []string{"datacenterA", "datacenterB"} } + +func TestReloadTaskListPartitionConfig(t *testing.T) { + ctrl := gomock.NewController(t) + mockPartitioner := partition.NewMockPartitioner(ctrl) + mockPartitioner.EXPECT().GetIsolationGroupByDomainID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() + mockDomainCache := cache.NewMockDomainCache(ctrl) + mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(cache.CreateDomainCacheEntry("domainName"), nil).AnyTimes() + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("domainName", nil).AnyTimes() + + mockTm := persistence.NewMockTaskManager(ctrl) + mockTm.EXPECT().GetTaskList(gomock.Any(), &persistence.GetTaskListRequest{ + DomainID: "domain", + DomainName: "domainName", + TaskList: "tasklist", + TaskType: persistence.TaskListTypeActivity, + }).Return(nil, errors.New("error")).Times(1) + mockTm.EXPECT().GetTaskList(gomock.Any(), &persistence.GetTaskListRequest{ + DomainID: "domain", + DomainName: "domainName", + TaskList: "tasklist", + TaskType: persistence.TaskListTypeActivity, + }).Return(&persistence.GetTaskListResponse{ + TaskListInfo: &persistence.TaskListInfo{ + AdaptivePartitionConfig: nil, + }, + }, nil).Times(1) + + tlID, err := NewIdentifier("domain", "/__cadence_sys/tasklist/1", persistence.TaskListTypeActivity) + require.NoError(t, err) + + tlMgr, err := NewManager( + mockDomainCache, + testlogger.New(t), + metrics.NewClient(tally.NoopScope, metrics.Matching), + mockTm, + cluster.GetTestClusterMetadata(true), + mockPartitioner, + nil, + func(Manager) {}, + tlID, + types.TaskListKindNormal.Ptr(), + defaultTestConfig(), + clock.NewRealTimeSource(), + time.Now()) + + tlm := tlMgr.(*taskListManagerImpl) + tlm.reloadTaskListPartitionConfig() + assert.Nil(t, tlm.TaskListPartitionConfig()) + tlm.reloadTaskListPartitionConfig() + assert.Equal(t, &types.TaskListPartitionConfig{NumReadPartitions: 1, NumWritePartitions: 1}, tlm.TaskListPartitionConfig()) + tlm.reloadTaskListPartitionConfig() + assert.Equal(t, &types.TaskListPartitionConfig{NumReadPartitions: 1, NumWritePartitions: 1}, tlm.TaskListPartitionConfig()) +} diff --git a/service/matching/tasklist/testing.go b/service/matching/tasklist/testing.go index 87c3a8f7cad..0d2132cc6c5 100644 --- a/service/matching/tasklist/testing.go +++ b/service/matching/tasklist/testing.go @@ -49,11 +49,13 @@ type ( } testTaskListManager struct { - sync.Mutex - rangeID int64 - ackLevel int64 - createTaskCount int - tasks *treemap.Map + sync.RWMutex + rangeID int64 + ackLevel int64 + kind int + createTaskCount int + tasks *treemap.Map + adaptivePartitionConfig *persistence.TaskListPartitionConfig } ) @@ -86,16 +88,18 @@ func (m *TestTaskManager) LeaseTaskList( tlm.Lock() defer tlm.Unlock() tlm.rangeID++ + tlm.kind = request.TaskListKind m.logger.Debug(fmt.Sprintf("testTaskManager.LeaseTaskList rangeID=%v", tlm.rangeID)) return &persistence.LeaseTaskListResponse{ TaskListInfo: &persistence.TaskListInfo{ - AckLevel: tlm.ackLevel, - DomainID: request.DomainID, - Name: request.TaskList, - TaskType: request.TaskType, - RangeID: tlm.rangeID, - Kind: request.TaskListKind, + AckLevel: tlm.ackLevel, + DomainID: request.DomainID, + Name: request.TaskList, + TaskType: request.TaskType, + RangeID: tlm.rangeID, + Kind: tlm.kind, + AdaptivePartitionConfig: tlm.adaptivePartitionConfig, }, }, nil } @@ -104,7 +108,20 @@ func (m *TestTaskManager) GetTaskList( _ context.Context, request *persistence.GetTaskListRequest, ) (*persistence.GetTaskListResponse, error) { - return nil, fmt.Errorf("not implemented") + tlm := m.getTaskListManager(NewTestTaskListID(m.t, request.DomainID, request.TaskList, request.TaskType)) + tlm.RLock() + defer tlm.RUnlock() + return &persistence.GetTaskListResponse{ + TaskListInfo: &persistence.TaskListInfo{ + AckLevel: tlm.ackLevel, + DomainID: request.DomainID, + Name: request.TaskList, + TaskType: request.TaskType, + RangeID: tlm.rangeID, + Kind: tlm.kind, + AdaptivePartitionConfig: tlm.adaptivePartitionConfig, + }, + }, nil } // UpdateTaskList provides a mock function with given fields: ctx, request