Skip to content

Commit

Permalink
Fix context leak in tests (#5377)
Browse files Browse the repository at this point in the history
* Comment: Added a call to defer to fix a possible context leak in code.

* Comment: Added deferred cancel outside createContext method.

---------

Co-authored-by: OpenRefactory, Inc <56681071+openrefactory@users.noreply.github.com>
Co-authored-by: Zijian <Shaddoll@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 14, 2024
1 parent 373aacd commit 346def2
Show file tree
Hide file tree
Showing 19 changed files with 975 additions and 332 deletions.
64 changes: 48 additions & 16 deletions host/activity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ func (s *IntegrationSuite) TestActivityHeartBeatWorkflow_Success() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))
Expand Down Expand Up @@ -117,8 +119,10 @@ func (s *IntegrationSuite) TestActivityHeartBeatWorkflow_Success() {
s.Equal(activityName, activityType.Name)
for i := 0; i < 10; i++ {
s.Logger.Info("Heartbeating for activity", tag.WorkflowActivityID(activityID), tag.Counter(i))
_, err := s.engine.RecordActivityTaskHeartbeat(createContext(), &types.RecordActivityTaskHeartbeatRequest{
ctx, cancel := createContext()
_, err := s.engine.RecordActivityTaskHeartbeat(ctx, &types.RecordActivityTaskHeartbeatRequest{
TaskToken: taskToken, Details: []byte("details")})
cancel()
s.Nil(err)
time.Sleep(10 * time.Millisecond)
}
Expand Down Expand Up @@ -188,7 +192,9 @@ func (s *IntegrationSuite) TestActivityHeartbeatDetailsDuringRetry() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))
Expand Down Expand Up @@ -244,8 +250,10 @@ func (s *IntegrationSuite) TestActivityHeartbeatDetailsDuringRetry() {
var err error
if activityExecutedCount == 0 {
s.Logger.Info("Heartbeating for activity:", tag.WorkflowActivityID(activityID))
_, err = s.engine.RecordActivityTaskHeartbeat(createContext(), &types.RecordActivityTaskHeartbeatRequest{
ctx, cancel := createContext()
_, err = s.engine.RecordActivityTaskHeartbeat(ctx, &types.RecordActivityTaskHeartbeatRequest{
TaskToken: taskToken, Details: heartbeatDetails})
cancel()
s.Nil(err)
// Trigger heartbeat timeout and retry
time.Sleep(time.Second * 2)
Expand All @@ -270,7 +278,9 @@ func (s *IntegrationSuite) TestActivityHeartbeatDetailsDuringRetry() {
}

describeWorkflowExecution := func() (*types.DescribeWorkflowExecutionResponse, error) {
return s.engine.DescribeWorkflowExecution(createContext(), &types.DescribeWorkflowExecutionRequest{
ctx, cancel := createContext()
defer cancel()
return s.engine.DescribeWorkflowExecution(ctx, &types.DescribeWorkflowExecutionRequest{
Domain: s.domainName,
Execution: &types.WorkflowExecution{
WorkflowID: id,
Expand Down Expand Up @@ -364,7 +374,9 @@ func (s *IntegrationSuite) TestActivityRetry() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))
Expand Down Expand Up @@ -489,7 +501,9 @@ func (s *IntegrationSuite) TestActivityRetry() {
}

describeWorkflowExecution := func() (*types.DescribeWorkflowExecutionResponse, error) {
return s.engine.DescribeWorkflowExecution(createContext(), &types.DescribeWorkflowExecutionRequest{
ctx, cancel := createContext()
defer cancel()
return s.engine.DescribeWorkflowExecution(ctx, &types.DescribeWorkflowExecutionRequest{
Domain: s.domainName,
Execution: &types.WorkflowExecution{
WorkflowID: id,
Expand Down Expand Up @@ -577,7 +591,9 @@ func (s *IntegrationSuite) TestActivityHeartBeatWorkflow_Timeout() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))
Expand Down Expand Up @@ -681,7 +697,9 @@ func (s *IntegrationSuite) TestActivityTimeouts() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))
Expand Down Expand Up @@ -848,8 +866,10 @@ func (s *IntegrationSuite) TestActivityTimeouts() {
go func() {
for i := 0; i < 6; i++ {
s.Logger.Info("Heartbeating for activity", tag.WorkflowActivityID(activityID), tag.Counter(i))
_, err := s.engine.RecordActivityTaskHeartbeat(createContext(), &types.RecordActivityTaskHeartbeatRequest{
ctx, cancel := createContext()
_, err := s.engine.RecordActivityTaskHeartbeat(ctx, &types.RecordActivityTaskHeartbeatRequest{
TaskToken: taskToken, Details: []byte(strconv.Itoa(i))})
cancel()
s.Nil(err)
time.Sleep(1 * time.Second)
}
Expand Down Expand Up @@ -923,7 +943,9 @@ func (s *IntegrationSuite) TestActivityHeartbeatTimeouts() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))
Expand Down Expand Up @@ -1028,8 +1050,10 @@ func (s *IntegrationSuite) TestActivityHeartbeatTimeouts() {
for i := 0; i < 10; i++ {
if !workflowComplete {
s.Logger.Info("Heartbeating for activity", tag.WorkflowActivityID(activityID), tag.Counter(i))
_, err := s.engine.RecordActivityTaskHeartbeat(createContext(), &types.RecordActivityTaskHeartbeatRequest{
ctx, cancel := createContext()
_, err := s.engine.RecordActivityTaskHeartbeat(ctx, &types.RecordActivityTaskHeartbeatRequest{
TaskToken: taskToken, Details: []byte(strconv.Itoa(i))})
cancel()
if err != nil {
s.Logger.Error("Activity heartbeat failed", tag.WorkflowActivityID(activityID), tag.Counter(i), tag.Error(err))
}
Expand Down Expand Up @@ -1114,7 +1138,9 @@ func (s *IntegrationSuite) TestActivityCancellation() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecution: response", tag.WorkflowRunID(we.GetRunID()))
Expand Down Expand Up @@ -1171,9 +1197,11 @@ func (s *IntegrationSuite) TestActivityCancellation() {
s.Equal(activityName, activityType.GetName())
for i := 0; i < 10; i++ {
s.Logger.Info("Heartbeating for activity", tag.WorkflowActivityID(activityID), tag.Counter(i))
response, err := s.engine.RecordActivityTaskHeartbeat(createContext(),
ctx, cancel := createContext()
response, err := s.engine.RecordActivityTaskHeartbeat(ctx,
&types.RecordActivityTaskHeartbeatRequest{
TaskToken: taskToken, Details: []byte("details")})
cancel()
if response.CancelRequested {
return []byte("Activity Cancelled."), true, nil
}
Expand Down Expand Up @@ -1241,7 +1269,9 @@ func (s *IntegrationSuite) TestActivityCancellationNotStarted() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecutionn", tag.WorkflowRunID(we.GetRunID()))
Expand Down Expand Up @@ -1315,7 +1345,9 @@ func (s *IntegrationSuite) TestActivityCancellationNotStarted() {
// Send signal so that worker can send an activity cancel
signalName := "my signal"
signalInput := []byte("my signal input.")
err = s.engine.SignalWorkflowExecution(createContext(), &types.SignalWorkflowExecutionRequest{
ctx, cancel = createContext()
defer cancel()
err = s.engine.SignalWorkflowExecution(ctx, &types.SignalWorkflowExecutionRequest{
Domain: s.domainName,
WorkflowExecution: &types.WorkflowExecution{
WorkflowID: id,
Expand Down
16 changes: 12 additions & 4 deletions host/archival_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ func (s *IntegrationSuite) TestVisibilityArchival() {
Query: fmt.Sprintf("CloseTime >= %v and CloseTime <= %v and WorkflowType = '%s'", startTime, endTime, workflowType),
}
for len(executions) == 0 || request.NextPageToken != nil {
response, err := s.engine.ListArchivedWorkflowExecutions(createContext(), request)
ctx, cancel := createContext()
response, err := s.engine.ListArchivedWorkflowExecutions(ctx, request)
cancel()
s.NoError(err)
s.NotNil(response)
executions = append(executions, response.GetExecutions()...)
Expand All @@ -148,7 +150,9 @@ func (s *IntegrationSuite) TestVisibilityArchival() {
}

func (s *IntegrationSuite) getDomainID(domain string) string {
domainResp, err := s.engine.DescribeDomain(createContext(), &types.DescribeDomainRequest{
ctx, cancel := createContext()
defer cancel()
domainResp, err := s.engine.DescribeDomain(ctx, &types.DescribeDomainRequest{
Name: common.StringPtr(s.archivalDomainName),
})
s.Nil(err)
Expand All @@ -162,7 +166,9 @@ func (s *IntegrationSuite) isHistoryArchived(domain string, execution *types.Wor
}

for i := 0; i < retryLimit; i++ {
getHistoryResp, err := s.engine.GetWorkflowExecutionHistory(createContext(), request)
ctx, cancel := createContext()
getHistoryResp, err := s.engine.GetWorkflowExecutionHistory(ctx, request)
cancel()
if err == nil && getHistoryResp != nil && getHistoryResp.GetArchived() {
return true
}
Expand Down Expand Up @@ -231,7 +237,9 @@ func (s *IntegrationSuite) startAndFinishWorkflow(id, wt, tl, domain, domainID s
TaskStartToCloseTimeoutSeconds: common.Int32Ptr(1),
Identity: identity,
}
we, err := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err)
s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))
RunIDs := make([]string, numRuns)
Expand Down
40 changes: 30 additions & 10 deletions host/cancel_workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ func (s *IntegrationSuite) TestExternalRequestCancelWorkflowExecution() {
Identity: identity,
}

we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)

s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))
Expand Down Expand Up @@ -119,7 +121,9 @@ func (s *IntegrationSuite) TestExternalRequestCancelWorkflowExecution() {
s.Logger.Info("PollAndProcessActivityTask", tag.Error(err))
s.Nil(err)

err = s.engine.RequestCancelWorkflowExecution(createContext(), &types.RequestCancelWorkflowExecutionRequest{
ctx, cancel = createContext()
defer cancel()
err = s.engine.RequestCancelWorkflowExecution(ctx, &types.RequestCancelWorkflowExecutionRequest{
Domain: s.domainName,
WorkflowExecution: &types.WorkflowExecution{
WorkflowID: id,
Expand All @@ -128,7 +132,9 @@ func (s *IntegrationSuite) TestExternalRequestCancelWorkflowExecution() {
})
s.Nil(err)

err = s.engine.RequestCancelWorkflowExecution(createContext(), &types.RequestCancelWorkflowExecutionRequest{
ctx, cancel = createContext()
defer cancel()
err = s.engine.RequestCancelWorkflowExecution(ctx, &types.RequestCancelWorkflowExecutionRequest{
Domain: s.domainName,
WorkflowExecution: &types.WorkflowExecution{
WorkflowID: id,
Expand All @@ -145,13 +151,15 @@ func (s *IntegrationSuite) TestExternalRequestCancelWorkflowExecution() {
executionCancelled := false
GetHistoryLoop:
for i := 1; i < 3; i++ {
historyResponse, err := s.engine.GetWorkflowExecutionHistory(createContext(), &types.GetWorkflowExecutionHistoryRequest{
ctx, cancel := createContext()
historyResponse, err := s.engine.GetWorkflowExecutionHistory(ctx, &types.GetWorkflowExecutionHistoryRequest{
Domain: s.domainName,
Execution: &types.WorkflowExecution{
WorkflowID: id,
RunID: we.RunID,
},
})
cancel()
s.Nil(err)
history := historyResponse.History

Expand Down Expand Up @@ -194,7 +202,9 @@ func (s *IntegrationSuite) TestRequestCancelWorkflowDecisionExecution() {
TaskStartToCloseTimeoutSeconds: common.Int32Ptr(1),
Identity: identity,
}
we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)
s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))

Expand All @@ -209,7 +219,9 @@ func (s *IntegrationSuite) TestRequestCancelWorkflowDecisionExecution() {
TaskStartToCloseTimeoutSeconds: common.Int32Ptr(1),
Identity: identity,
}
we2, err0 := s.engine.StartWorkflowExecution(createContext(), foreignRequest)
ctx, cancel = createContext()
defer cancel()
we2, err0 := s.engine.StartWorkflowExecution(ctx, foreignRequest)
s.Nil(err0)
s.Logger.Info("StartWorkflowExecution on foreign Domain: %v, response: %v \n", tag.WorkflowDomainName(s.foreignDomainName), tag.WorkflowRunID(we2.RunID))

Expand Down Expand Up @@ -328,13 +340,15 @@ func (s *IntegrationSuite) TestRequestCancelWorkflowDecisionExecution() {
intiatedEventID := 10
CheckHistoryLoopForCancelSent:
for i := 1; i < 10; i++ {
historyResponse, err := s.engine.GetWorkflowExecutionHistory(createContext(), &types.GetWorkflowExecutionHistoryRequest{
ctx, cancel := createContext()
historyResponse, err := s.engine.GetWorkflowExecutionHistory(ctx, &types.GetWorkflowExecutionHistoryRequest{
Domain: s.domainName,
Execution: &types.WorkflowExecution{
WorkflowID: id,
RunID: we.RunID,
},
})
cancel()
s.Nil(err)
history := historyResponse.History

Expand Down Expand Up @@ -364,13 +378,15 @@ CheckHistoryLoopForCancelSent:
executionCancelled := false
GetHistoryLoop:
for i := 1; i < 10; i++ {
historyResponse, err := s.engine.GetWorkflowExecutionHistory(createContext(), &types.GetWorkflowExecutionHistoryRequest{
ctx, cancel := createContext()
historyResponse, err := s.engine.GetWorkflowExecutionHistory(ctx, &types.GetWorkflowExecutionHistoryRequest{
Domain: s.foreignDomainName,
Execution: &types.WorkflowExecution{
WorkflowID: id,
RunID: we2.RunID,
},
})
cancel()
s.Nil(err)
history := historyResponse.History

Expand Down Expand Up @@ -428,7 +444,9 @@ func (s *IntegrationSuite) TestRequestCancelWorkflowDecisionExecution_UnKnownTar
TaskStartToCloseTimeoutSeconds: common.Int32Ptr(1),
Identity: identity,
}
we, err0 := s.engine.StartWorkflowExecution(createContext(), request)
ctx, cancel := createContext()
defer cancel()
we, err0 := s.engine.StartWorkflowExecution(ctx, request)
s.Nil(err0)
s.Logger.Info("StartWorkflowExecution", tag.WorkflowRunID(we.RunID))

Expand Down Expand Up @@ -496,13 +514,15 @@ func (s *IntegrationSuite) TestRequestCancelWorkflowDecisionExecution_UnKnownTar
intiatedEventID := 10
CheckHistoryLoopForCancelSent:
for i := 1; i < 10; i++ {
historyResponse, err := s.engine.GetWorkflowExecutionHistory(createContext(), &types.GetWorkflowExecutionHistoryRequest{
ctx, cancel := createContext()
historyResponse, err := s.engine.GetWorkflowExecutionHistory(ctx, &types.GetWorkflowExecutionHistoryRequest{
Domain: s.domainName,
Execution: &types.WorkflowExecution{
WorkflowID: id,
RunID: we.RunID,
},
})
cancel()
s.Nil(err)
history := historyResponse.History

Expand Down
Loading

0 comments on commit 346def2

Please sign in to comment.