Skip to content

Commit

Permalink
Add unit tests for matching engine
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll committed Jun 5, 2024
1 parent b38bd0c commit c98a6d2
Show file tree
Hide file tree
Showing 5 changed files with 937 additions and 50 deletions.
38 changes: 17 additions & 21 deletions service/matching/handler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ type (
membershipResolver membership.Resolver
partitioner partition.Partitioner
timeSource clock.TimeSource

waitForQueryResultFn func(hCtx *handlerContext, isStrongConsistencyQuery bool, queryResultCh <-chan *queryResult) (*types.QueryWorkflowResponse, error)
}

// HistoryInfo consists of two integer regarding the history size and history count
Expand Down Expand Up @@ -128,7 +130,7 @@ func NewEngine(taskManager persistence.TaskManager,
partitioner partition.Partitioner,
timeSource clock.TimeSource,
) Engine {
return &matchingEngineImpl{
e := &matchingEngineImpl{
taskManager: taskManager,
clusterMetadata: clusterMetadata,
historyService: historyService,
Expand All @@ -145,6 +147,8 @@ func NewEngine(taskManager persistence.TaskManager,
partitioner: partitioner,
timeSource: timeSource,
}
e.waitForQueryResultFn = e.waitForQueryResult
return e
}

func (e *matchingEngineImpl) Start() {
Expand Down Expand Up @@ -249,15 +253,14 @@ func (e *matchingEngineImpl) getTaskListByDomainLocked(domainID string) *types.G
decisionTaskListMap := make(map[string]*types.DescribeTaskListResponse)
activityTaskListMap := make(map[string]*types.DescribeTaskListResponse)
for tl, tlm := range e.taskLists {
if tlm.GetTaskListKind() == types.TaskListKindNormal && tl.GetDomainID() == domainID {
if tl.GetDomainID() == domainID && tlm.GetTaskListKind() == types.TaskListKindNormal {
if types.TaskListType(tl.GetType()) == types.TaskListTypeDecision {
decisionTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false)
} else {
activityTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false)
}
// TODO: review this logic
activityTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false)
}
}

return &types.GetTaskListsByDomainResponse{
DecisionTaskListMap: decisionTaskListMap,
ActivityTaskListMap: activityTaskListMap,
Expand Down Expand Up @@ -712,23 +715,24 @@ func (e *matchingEngineImpl) QueryWorkflow(
queryResultCh := make(chan *queryResult, 1)
e.lockableQueryTaskMap.put(taskID, queryResultCh)
defer e.lockableQueryTaskMap.delete(taskID)
return e.waitForQueryResultFn(hCtx, queryRequest.GetQueryRequest().GetQueryConsistencyLevel() == types.QueryConsistencyLevelStrong, queryResultCh)
}

func (e *matchingEngineImpl) waitForQueryResult(hCtx *handlerContext, isStrongConsistencyQuery bool, queryResultCh <-chan *queryResult) (*types.QueryWorkflowResponse, error) {
select {
case result := <-queryResultCh:
if result.internalError != nil {
return nil, result.internalError
}

workerResponse := result.workerResponse
// if query was intended as consistent query check to see if worker supports consistent query
if queryRequest.GetQueryRequest().GetQueryConsistencyLevel() == types.QueryConsistencyLevelStrong {
if isStrongConsistencyQuery {
if err := e.versionChecker.SupportsConsistentQuery(
workerResponse.GetCompletedRequest().GetWorkerVersionInfo().GetImpl(),
workerResponse.GetCompletedRequest().GetWorkerVersionInfo().GetFeatureVersion()); err != nil {
return nil, err
}
}

switch workerResponse.GetCompletedRequest().GetCompletedType() {
case types.QueryTaskCompletedTypeCompleted:
return &types.QueryWorkflowResponse{QueryResult: workerResponse.GetCompletedRequest().GetQueryResult()}, nil
Expand Down Expand Up @@ -878,30 +882,22 @@ func (e *matchingEngineImpl) getAllPartitions(
request *types.MatchingListTaskListPartitionsRequest,
taskListType int,
) ([]string, error) {
var partitionKeys []string
domainID, err := e.domainCache.GetDomainID(request.GetDomain())
if err != nil {
return partitionKeys, err
return nil, err
}
taskList := request.GetTaskList()
taskListID, err := tasklist.NewIdentifier(domainID, taskList.GetName(), taskListType)
if err != nil {
return partitionKeys, err
}
rootPartition := taskListID.GetRoot()

partitionKeys = append(partitionKeys, rootPartition)

nWritePartitions := e.config.NumTasklistWritePartitions
n := nWritePartitions(request.GetDomain(), rootPartition, taskListType)
if n <= 0 {
return partitionKeys, nil
return nil, err
}

rootPartition := taskListID.GetRoot()
partitionKeys := []string{rootPartition}
n := e.config.NumTasklistWritePartitions(request.GetDomain(), rootPartition, taskListType)
for i := 1; i < n; i++ {
partitionKeys = append(partitionKeys, fmt.Sprintf("%v%v/%v", common.ReservedTaskListPrefix, rootPartition, i))
}

return partitionKeys, nil
}

Expand Down
Loading

0 comments on commit c98a6d2

Please sign in to comment.