Skip to content

Commit

Permalink
Update matching handlers to return TaskListPartitionConfig in responses
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll committed Oct 19, 2024
1 parent 9b75b7d commit f3389b9
Show file tree
Hide file tree
Showing 15 changed files with 382 additions and 116 deletions.
11 changes: 11 additions & 0 deletions common/persistence/data_manager_interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -2054,3 +2054,14 @@ func (e *WorkflowExecutionInfo) CopyPartitionConfig() map[string]string {
}
return partitionConfig
}

func (p *TaskListPartitionConfig) ToInternalType() *types.TaskListPartitionConfig {
if p == nil {
return nil
}
return &types.TaskListPartitionConfig{
Version: p.Version,
NumReadPartitions: int32(p.NumReadPartitions),
NumWritePartitions: int32(p.NumWritePartitions),
}
}
38 changes: 38 additions & 0 deletions common/persistence/data_manager_interfaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,41 @@ func TestCopyPartitionConfig(t *testing.T) {
})
}
}

func TestTaskListPartitionConfigToInternalType(t *testing.T) {
testCases := []struct {
name string
input *TaskListPartitionConfig
expect *types.TaskListPartitionConfig
}{
{
name: "nil case",
input: nil,
expect: nil,
},
{
name: "empty case",
input: &TaskListPartitionConfig{},
expect: &types.TaskListPartitionConfig{},
},
{
name: "normal case",
input: &TaskListPartitionConfig{
Version: 1,
NumReadPartitions: 2,
NumWritePartitions: 3,
},
expect: &types.TaskListPartitionConfig{
Version: 1,
NumReadPartitions: 2,
NumWritePartitions: 3,
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expect, tc.input.ToInternalType())
})
}
}
4 changes: 2 additions & 2 deletions service/history/task/transfer_active_task_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,8 @@ func (s *transferActiveTaskExecutorSuite) TestProcessDecisionTask_StickyWorkerUn
}

gomock.InOrder(
s.mockMatchingClient.EXPECT().AddDecisionTask(gomock.Any(), addDecisionTaskRequest).Return(&types.StickyWorkerUnavailableError{}).Times(1),
s.mockMatchingClient.EXPECT().AddDecisionTask(gomock.Any(), gomock.Eq(&modifiedRequest)).Return(nil).Times(1),
s.mockMatchingClient.EXPECT().AddDecisionTask(gomock.Any(), addDecisionTaskRequest).Return(nil, &types.StickyWorkerUnavailableError{}).Times(1),
s.mockMatchingClient.EXPECT().AddDecisionTask(gomock.Any(), gomock.Eq(&modifiedRequest)).Return(&types.AddDecisionTaskResponse{}, nil).Times(1),
)

err = s.transferActiveTaskExecutor.Execute(transferTask, true)
Expand Down
89 changes: 60 additions & 29 deletions service/matching/handler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ func (e *matchingEngineImpl) removeTaskListManager(tlMgr tasklist.Manager) {
func (e *matchingEngineImpl) AddDecisionTask(
hCtx *handlerContext,
request *types.AddDecisionTaskRequest,
) (bool, error) {
) (*types.AddDecisionTaskResponse, error) {
startT := time.Now()
domainID := request.GetDomainUUID()
taskListName := request.GetTaskList().GetName()
taskListKind := request.GetTaskList().Kind
Expand Down Expand Up @@ -355,13 +356,13 @@ func (e *matchingEngineImpl) AddDecisionTask(

taskListID, err := tasklist.NewIdentifier(domainID, taskListName, taskListType)
if err != nil {
return false, err
return nil, err
}

// get the domainName
domainName, err := e.domainCache.GetDomainName(domainID)
if err != nil {
return false, err
return nil, err
}

// Only emit traffic metrics if the tasklist is not sticky and is not forwarded
Expand All @@ -376,13 +377,13 @@ func (e *matchingEngineImpl) AddDecisionTask(

tlMgr, err := e.getTaskListManager(taskListID, taskListKind)
if err != nil {
return false, err
return nil, err
}

if taskListKind != nil && *taskListKind == types.TaskListKindSticky {
// check if the sticky worker is still available, if not, fail this request early
if !tlMgr.HasPollerAfter(e.timeSource.Now().Add(-_stickyPollerUnavailableWindow)) {
return false, _stickyPollerUnavailableError
return nil, _stickyPollerUnavailableError
}
}

Expand All @@ -396,18 +397,28 @@ func (e *matchingEngineImpl) AddDecisionTask(
PartitionConfig: request.GetPartitionConfig(),
}

return tlMgr.AddTask(hCtx.Context, tasklist.AddTaskParams{
syncMatched, err := tlMgr.AddTask(hCtx.Context, tasklist.AddTaskParams{
TaskInfo: taskInfo,
Source: request.GetSource(),
ForwardedFrom: request.GetForwardedFrom(),
})
if err != nil {
return nil, err
}
if syncMatched {
hCtx.scope.RecordTimer(metrics.SyncMatchLatencyPerTaskList, time.Since(startT))
}
return &types.AddDecisionTaskResponse{
PartitionConfig: tlMgr.TaskListPartitionConfig(),
}, nil
}

// AddActivityTask either delivers task directly to waiting poller or save it into task list persistence.
func (e *matchingEngineImpl) AddActivityTask(
hCtx *handlerContext,
request *types.AddActivityTaskRequest,
) (bool, error) {
) (*types.AddActivityTaskResponse, error) {
startT := time.Now()
domainID := request.GetDomainUUID()
taskListName := request.GetTaskList().GetName()
taskListKind := request.GetTaskList().Kind
Expand All @@ -427,13 +438,13 @@ func (e *matchingEngineImpl) AddActivityTask(

taskListID, err := tasklist.NewIdentifier(domainID, taskListName, taskListType)
if err != nil {
return false, err
return nil, err
}

// get the domainName
domainName, err := e.domainCache.GetDomainName(domainID)
if err != nil {
return false, err
return nil, err
}

// Only emit traffic metrics if the tasklist is not sticky and is not forwarded
Expand All @@ -448,7 +459,7 @@ func (e *matchingEngineImpl) AddActivityTask(

tlMgr, err := e.getTaskListManager(taskListID, taskListKind)
if err != nil {
return false, err
return nil, err
}

taskInfo := &persistence.TaskInfo{
Expand All @@ -461,12 +472,21 @@ func (e *matchingEngineImpl) AddActivityTask(
PartitionConfig: request.GetPartitionConfig(),
}

return tlMgr.AddTask(hCtx.Context, tasklist.AddTaskParams{
syncMatched, err := tlMgr.AddTask(hCtx.Context, tasklist.AddTaskParams{
TaskInfo: taskInfo,
Source: request.GetSource(),
ForwardedFrom: request.GetForwardedFrom(),
ActivityTaskDispatchInfo: request.ActivityTaskDispatchInfo,
})
if err != nil {
return nil, err
}
if syncMatched {
hCtx.scope.RecordTimer(metrics.SyncMatchLatencyPerTaskList, time.Since(startT))
}
return &types.AddActivityTaskResponse{
PartitionConfig: tlMgr.TaskListPartitionConfig(),
}, nil
}

// PollForDecisionTask tries to get the decision task using exponential backoff.
Expand Down Expand Up @@ -512,7 +532,11 @@ pollLoop:
pollerCtx := tasklist.ContextWithPollerID(hCtx.Context, pollerID)
pollerCtx = tasklist.ContextWithIdentity(pollerCtx, request.GetIdentity())
pollerCtx = tasklist.ContextWithIsolationGroup(pollerCtx, req.GetIsolationGroup())
task, err := e.getTask(pollerCtx, taskListID, nil, taskListKind)
tlMgr, err := e.getTaskListManager(taskListID, taskListKind)
if err != nil {
return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err)
}
task, err := tlMgr.GetTask(pollerCtx, nil)
if err != nil {
// TODO: Is empty poll the best reply for errPumpClosed?
if errors.Is(err, tasklist.ErrNoTasks) || errors.Is(err, errPumpClosed) {
Expand All @@ -534,7 +558,9 @@ pollLoop:
"RequestForwardedFrom": req.GetForwardedFrom(),
},
})
return emptyPollForDecisionTaskResponse, nil
return &types.MatchingPollForDecisionTaskResponse{
PartitionConfig: tlMgr.TaskListPartitionConfig(),
}, nil
}
return nil, fmt.Errorf("couldn't get task: %w", err)
}
Expand Down Expand Up @@ -568,7 +594,9 @@ pollLoop:
if err != nil {
// will notify query client that the query task failed
e.deliverQueryResult(task.Query.TaskID, &queryResult{internalError: err}) //nolint:errcheck
return emptyPollForDecisionTaskResponse, nil
return &types.MatchingPollForDecisionTaskResponse{
PartitionConfig: tlMgr.TaskListPartitionConfig(),
}, nil
}

isStickyEnabled := false
Expand All @@ -585,7 +613,7 @@ pollLoop:
BranchToken: mutableStateResp.CurrentBranchToken,
HistorySize: mutableStateResp.HistorySize,
}
return e.createPollForDecisionTaskResponse(task, resp, hCtx.scope), nil
return e.createPollForDecisionTaskResponse(task, resp, hCtx.scope, tlMgr.TaskListPartitionConfig()), nil
}

e.emitTaskIsolationMetrics(hCtx.scope, task.Event.PartitionConfig, req.GetIsolationGroup())
Expand Down Expand Up @@ -641,7 +669,7 @@ pollLoop:
},
})

return e.createPollForDecisionTaskResponse(task, resp, hCtx.scope), nil
return e.createPollForDecisionTaskResponse(task, resp, hCtx.scope, tlMgr.TaskListPartitionConfig()), nil
}
}

Expand Down Expand Up @@ -683,11 +711,17 @@ pollLoop:
pollerCtx = tasklist.ContextWithIdentity(pollerCtx, request.GetIdentity())
pollerCtx = tasklist.ContextWithIsolationGroup(pollerCtx, req.GetIsolationGroup())
taskListKind := request.TaskList.Kind
task, err := e.getTask(pollerCtx, taskListID, maxDispatch, taskListKind)
tlMgr, err := e.getTaskListManager(taskListID, taskListKind)
if err != nil {
return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err)
}
task, err := tlMgr.GetTask(pollerCtx, maxDispatch)
if err != nil {
// TODO: Is empty poll the best reply for errPumpClosed?
if errors.Is(err, tasklist.ErrNoTasks) || errors.Is(err, errPumpClosed) {
return emptyPollForActivityTaskResponse, nil
return &types.MatchingPollForActivityTaskResponse{
PartitionConfig: tlMgr.TaskListPartitionConfig(),
}, nil
}
e.logger.Error("Received unexpected err while getting task",
tag.WorkflowTaskListName(taskListName),
Expand All @@ -705,7 +739,7 @@ pollLoop:
e.emitTaskIsolationMetrics(hCtx.scope, task.Event.PartitionConfig, req.GetIsolationGroup())
if task.ActivityTaskDispatchInfo != nil {
task.Finish(nil)
return e.createSyncMatchPollForActivityTaskResponse(task, task.ActivityTaskDispatchInfo), nil
return e.createSyncMatchPollForActivityTaskResponse(task, task.ActivityTaskDispatchInfo, tlMgr.TaskListPartitionConfig()), nil
}

resp, err := e.recordActivityTaskStarted(hCtx.Context, request, task)
Expand Down Expand Up @@ -737,13 +771,14 @@ pollLoop:
continue pollLoop
}
task.Finish(nil)
return e.createPollForActivityTaskResponse(task, resp, hCtx.scope), nil
return e.createPollForActivityTaskResponse(task, resp, hCtx.scope, tlMgr.TaskListPartitionConfig()), nil
}
}

func (e *matchingEngineImpl) createSyncMatchPollForActivityTaskResponse(
task *tasklist.InternalTask,
activityTaskDispatchInfo *types.ActivityTaskDispatchInfo,
partitionConfig *types.TaskListPartitionConfig,
) *types.MatchingPollForActivityTaskResponse {

scheduledEvent := activityTaskDispatchInfo.ScheduledEvent
Expand Down Expand Up @@ -777,6 +812,7 @@ func (e *matchingEngineImpl) createSyncMatchPollForActivityTaskResponse(
response.HeartbeatDetails = activityTaskDispatchInfo.HeartbeatDetails
response.WorkflowType = activityTaskDispatchInfo.WorkflowType
response.WorkflowDomain = activityTaskDispatchInfo.WorkflowDomain
response.PartitionConfig = partitionConfig
return response
}

Expand Down Expand Up @@ -1006,15 +1042,6 @@ func (e *matchingEngineImpl) getAllPartitions(
return partitionKeys, nil
}

// Loads a task from persistence and wraps it in a task context
func (e *matchingEngineImpl) getTask(ctx context.Context, taskList *tasklist.Identifier, maxDispatchPerSecond *float64, taskListKind *types.TaskListKind) (*tasklist.InternalTask, error) {
tlMgr, err := e.getTaskListManager(taskList, taskListKind)
if err != nil {
return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err)
}
return tlMgr.GetTask(ctx, maxDispatchPerSecond)
}

func (e *matchingEngineImpl) unloadTaskList(tlMgr tasklist.Manager) {
id := tlMgr.TaskListID()
e.taskListsLock.Lock()
Expand All @@ -1033,6 +1060,7 @@ func (e *matchingEngineImpl) createPollForDecisionTaskResponse(
task *tasklist.InternalTask,
historyResponse *types.RecordDecisionTaskStartedResponse,
scope metrics.Scope,
partitionConfig *types.TaskListPartitionConfig,
) *types.MatchingPollForDecisionTaskResponse {

var token []byte
Expand Down Expand Up @@ -1067,6 +1095,7 @@ func (e *matchingEngineImpl) createPollForDecisionTaskResponse(
response.Query = task.Query.Request.QueryRequest.Query
}
response.BacklogCountHint = task.BacklogCountHint
response.PartitionConfig = partitionConfig
return response
}

Expand All @@ -1075,6 +1104,7 @@ func (e *matchingEngineImpl) createPollForActivityTaskResponse(
task *tasklist.InternalTask,
historyResponse *types.RecordActivityTaskStartedResponse,
scope metrics.Scope,
partitionConfig *types.TaskListPartitionConfig,
) *types.MatchingPollForActivityTaskResponse {

scheduledEvent := historyResponse.ScheduledEvent
Expand Down Expand Up @@ -1118,6 +1148,7 @@ func (e *matchingEngineImpl) createPollForActivityTaskResponse(
response.HeartbeatDetails = historyResponse.HeartbeatDetails
response.WorkflowType = historyResponse.WorkflowType
response.WorkflowDomain = historyResponse.WorkflowDomain
response.PartitionConfig = partitionConfig
return response
}

Expand Down
Loading

0 comments on commit f3389b9

Please sign in to comment.