Skip to content

Commit

Permalink
Enforce persistence context timeout in application layer: Part 2 (cad…
Browse files Browse the repository at this point in the history
…ence-workflow#3590)

Enforce persistence context timeout in reconciliation package and execution scanner
  • Loading branch information
yycptt authored Feb 4, 2021
1 parent 98347b1 commit 9afea11
Show file tree
Hide file tree
Showing 30 changed files with 256 additions and 121 deletions.
6 changes: 4 additions & 2 deletions common/ndc/history_resender.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ func (n *HistoryResenderImpl) SendSingleWorkflowHistory(
// Case 1: the workflow pass the retention period
// Case 2: the workflow is corrupted
if skipTask := n.fixCurrentExecution(
ctx,
domainID,
workflowID,
runID,
Expand Down Expand Up @@ -305,6 +306,7 @@ func (n *HistoryResenderImpl) getHistory(
}

func (n *HistoryResenderImpl) fixCurrentExecution(
ctx context.Context,
domainID string,
workflowID string,
runID string,
Expand All @@ -320,7 +322,7 @@ func (n *HistoryResenderImpl) fixCurrentExecution(
State: persistence.WorkflowStateRunning,
},
}
res := n.currentExecutionCheck.Check(execution)
res := n.currentExecutionCheck.Check(ctx, execution)
switch res.CheckResultType {
case invariant.CheckResultTypeCorrupted:
n.logger.Error(
Expand All @@ -329,7 +331,7 @@ func (n *HistoryResenderImpl) fixCurrentExecution(
tag.WorkflowID(workflowID),
tag.WorkflowRunID(runID),
)
n.currentExecutionCheck.Fix(execution)
n.currentExecutionCheck.Fix(ctx, execution)
return false
case invariant.CheckResultTypeFailed:
return false
Expand Down
10 changes: 5 additions & 5 deletions common/ndc/history_resender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,17 @@ func (s *historyResenderSuite) TestCurrentExecutionCheck() {
State: persistence.WorkflowStateRunning,
},
}
invariantMock.EXPECT().Check(execution1).Return(invariant.CheckResult{
invariantMock.EXPECT().Check(gomock.Any(), execution1).Return(invariant.CheckResult{
CheckResultType: invariant.CheckResultTypeCorrupted,
}).Times(1)
invariantMock.EXPECT().Check(execution2).Return(invariant.CheckResult{
invariantMock.EXPECT().Check(gomock.Any(), execution2).Return(invariant.CheckResult{
CheckResultType: invariant.CheckResultTypeHealthy,
}).Times(1)
invariantMock.EXPECT().Fix(gomock.Any()).Return(invariant.FixResult{}).Times(1)
invariantMock.EXPECT().Fix(gomock.Any(), gomock.Any()).Return(invariant.FixResult{}).Times(1)

skipTask := s.rereplicator.fixCurrentExecution(domainID, workflowID1, runID)
skipTask := s.rereplicator.fixCurrentExecution(context.Background(), domainID, workflowID1, runID)
s.False(skipTask)
skipTask = s.rereplicator.fixCurrentExecution(domainID, workflowID2, runID)
skipTask = s.rereplicator.fixCurrentExecution(context.Background(), domainID, workflowID2, runID)
s.True(skipTask)
}

Expand Down
7 changes: 5 additions & 2 deletions common/pagination/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@

package pagination

import "errors"
import (
"context"
"errors"
)

// ErrIteratorFinished indicates that Next was called on a finished iterator
var ErrIteratorFinished = errors.New("iterator has reached end")
Expand Down Expand Up @@ -52,7 +55,7 @@ type (
ShouldFlushFn func(Page) bool
// FetchFn fetches Page from PageToken.
// Once a page with nil NextToken is returned no more pages will be fetched.
FetchFn func(PageToken) (Page, error)
FetchFn func(context.Context, PageToken) (Page, error)
)

type (
Expand Down
11 changes: 8 additions & 3 deletions common/pagination/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@

package pagination

import "context"

type (
iterator struct {
ctx context.Context

page Page
entityIndex int

Expand All @@ -36,6 +40,7 @@ type (

// NewIterator constructs a new Iterator
func NewIterator(
ctx context.Context,
startingPageToken PageToken,
fetchFn FetchFn,
) Iterator {
Expand All @@ -56,9 +61,9 @@ func NewIterator(
// Returning nil, nil is valid if that is what the provided fetch function provided.
func (i *iterator) Next() (Entity, error) {
entity := i.nextEntity
error := i.nextError
err := i.nextError
i.advance(false)
return entity, error
return entity, err
}

// HasNext returns true if there is a next element. There is considered to be a next element
Expand Down Expand Up @@ -86,7 +91,7 @@ func (i *iterator) advanceToNonEmptyPage(firstPage bool) error {
if i.page.NextToken == nil && !firstPage {
return ErrIteratorFinished
}
nextPage, err := i.fetchFn(i.page.NextToken)
nextPage, err := i.fetchFn(i.ctx, i.page.NextToken)
if err != nil {
return err
}
Expand Down
13 changes: 7 additions & 6 deletions common/pagination/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package pagination

import (
"context"
"errors"
"testing"

Expand Down Expand Up @@ -52,7 +53,7 @@ func (s *IteratorSuite) SetupTest() {
}

func (s *IteratorSuite) TestInitializedToEmpty() {
fetchFn := func(token PageToken) (Page, error) {
fetchFn := func(_ context.Context, token PageToken) (Page, error) {
if token.(int) == 2 {
return Page{
CurrentToken: token,
Expand All @@ -66,14 +67,14 @@ func (s *IteratorSuite) TestInitializedToEmpty() {
Entities: fetchMap[token],
}, nil
}
itr := NewIterator(0, fetchFn)
itr := NewIterator(context.Background(), 0, fetchFn)
s.False(itr.HasNext())
_, err := itr.Next()
s.Equal(ErrIteratorFinished, err)
}

func (s *IteratorSuite) TestNonEmptyNoErrors() {
fetchFn := func(token PageToken) (Page, error) {
fetchFn := func(_ context.Context, token PageToken) (Page, error) {
var nextPageToken interface{} = token.(int) + 1
if nextPageToken.(int) == 5 {
nextPageToken = nil
Expand All @@ -84,7 +85,7 @@ func (s *IteratorSuite) TestNonEmptyNoErrors() {
Entities: fetchMap[token],
}, nil
}
itr := NewIterator(0, fetchFn)
itr := NewIterator(context.Background(), 0, fetchFn)
expectedResults := []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}
i := 0
for itr.HasNext() {
Expand All @@ -99,7 +100,7 @@ func (s *IteratorSuite) TestNonEmptyNoErrors() {
}

func (s *IteratorSuite) TestNonEmptyWithErrors() {
fetchFn := func(token PageToken) (Page, error) {
fetchFn := func(_ context.Context, token PageToken) (Page, error) {
if token.(int) == 4 {
return Page{}, errors.New("got error")
}
Expand All @@ -109,7 +110,7 @@ func (s *IteratorSuite) TestNonEmptyWithErrors() {
Entities: fetchMap[token],
}, nil
}
itr := NewIterator(0, fetchFn)
itr := NewIterator(context.Background(), 0, fetchFn)
expectedResults := []string{"one", "two", "three", "four", "five", "six", "seven"}
i := 0
for itr.HasNext() {
Expand Down
5 changes: 3 additions & 2 deletions common/pagination/writerIterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package pagination

import (
"bytes"
"context"
"encoding/json"
"fmt"
"testing"
Expand Down Expand Up @@ -86,7 +87,7 @@ func (s *WriterIteratorSuite) TestWriterIterator() {
s.Equal(expectedKey, flushedKeys[i].(string))
}

fetchFn := func(token PageToken) (Page, error) {
fetchFn := func(_ context.Context, token PageToken) (Page, error) {
key := flushedKeys[token.(int)]
data := store[key.(string)]
dataBlobs := bytes.Split(data, separator)
Expand All @@ -112,7 +113,7 @@ func (s *WriterIteratorSuite) TestWriterIterator() {
}, nil
}

itr := NewIterator(0, fetchFn)
itr := NewIterator(context.Background(), 0, fetchFn)
itrCount := 0
for itr.HasNext() {
val, err := itr.Next()
Expand Down
20 changes: 14 additions & 6 deletions common/reconciliation/fetcher/concrete.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,20 @@ import (
)

// ConcreteExecutionIterator is used to retrieve Concrete executions.
func ConcreteExecutionIterator(retryer persistence.Retryer, pageSize int) pagination.Iterator {
return pagination.NewIterator(nil, getConcreteExecutions(retryer, pageSize, codec.NewThriftRWEncoder()))
func ConcreteExecutionIterator(
ctx context.Context,
retryer persistence.Retryer,
pageSize int,
) pagination.Iterator {
return pagination.NewIterator(ctx, nil, getConcreteExecutions(retryer, pageSize, codec.NewThriftRWEncoder()))
}

// ConcreteExecution returns a single ConcreteExecution from persistence
func ConcreteExecution(retryer persistence.Retryer, request ExecutionRequest) (entity.Entity, error) {
func ConcreteExecution(
ctx context.Context,
retryer persistence.Retryer,
request ExecutionRequest,
) (entity.Entity, error) {

req := persistence.GetWorkflowExecutionRequest{
DomainID: request.DomainID,
Expand All @@ -48,7 +56,7 @@ func ConcreteExecution(retryer persistence.Retryer, request ExecutionRequest) (e
RunId: common.StringPtr(request.RunID),
},
}
e, err := retryer.GetWorkflowExecution(context.TODO(), &req)
e, err := retryer.GetWorkflowExecution(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -74,14 +82,14 @@ func getConcreteExecutions(
pageSize int,
encoder *codec.ThriftRWEncoder,
) pagination.FetchFn {
return func(token pagination.PageToken) (pagination.Page, error) {
return func(ctx context.Context, token pagination.PageToken) (pagination.Page, error) {
req := &persistence.ListConcreteExecutionsRequest{
PageSize: pageSize,
}
if token != nil {
req.PageToken = token.([]byte)
}
resp, err := pr.ListConcreteExecutions(context.TODO(), req)
resp, err := pr.ListConcreteExecutions(ctx, req)
if err != nil {
return pagination.Page{}, err
}
Expand Down
20 changes: 14 additions & 6 deletions common/reconciliation/fetcher/current.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,25 @@ import (
)

// CurrentExecutionIterator is used to retrieve Concrete executions.
func CurrentExecutionIterator(retryer persistence.Retryer, pageSize int) pagination.Iterator {
return pagination.NewIterator(nil, getCurrentExecution(retryer, pageSize))
func CurrentExecutionIterator(
ctx context.Context,
retryer persistence.Retryer,
pageSize int,
) pagination.Iterator {
return pagination.NewIterator(ctx, nil, getCurrentExecution(retryer, pageSize))
}

// CurrentExecution returns a single execution
func CurrentExecution(retryer persistence.Retryer, request ExecutionRequest) (entity.Entity, error) {
func CurrentExecution(
ctx context.Context,
retryer persistence.Retryer,
request ExecutionRequest,
) (entity.Entity, error) {
req := persistence.GetCurrentExecutionRequest{
DomainID: request.DomainID,
WorkflowID: request.WorkflowID,
}
e, err := retryer.GetCurrentExecution(context.TODO(), &req)
e, err := retryer.GetCurrentExecution(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -62,14 +70,14 @@ func getCurrentExecution(
pr persistence.Retryer,
pageSize int,
) pagination.FetchFn {
return func(token pagination.PageToken) (pagination.Page, error) {
return func(ctx context.Context, token pagination.PageToken) (pagination.Page, error) {
req := &persistence.ListCurrentExecutionsRequest{
PageSize: pageSize,
}
if token != nil {
req.PageToken = token.([]byte)
}
resp, err := pr.ListCurrentExecutions(context.TODO(), req)
resp, err := pr.ListCurrentExecutions(ctx, req)
if err != nil {
return pagination.Page{}, err
}
Expand Down
24 changes: 17 additions & 7 deletions common/reconciliation/invariant/concreteExecutionExists.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ func NewConcreteExecutionExists(
}

func (c *concreteExecutionExists) Check(
ctx context.Context,
execution interface{},
) CheckResult {
if checkResult := validateCheckContext(ctx, c.Name()); checkResult != nil {
return *checkResult
}

currentExecution, ok := execution.(*entity.CurrentExecution)
if !ok {
Expand All @@ -62,13 +66,13 @@ func (c *concreteExecutionExists) Check(
if len(currentExecution.CurrentRunID) == 0 {
// set the current run id
var runIDCheckResult *CheckResult
currentExecution, runIDCheckResult = c.validateCurrentRunID(currentExecution)
currentExecution, runIDCheckResult = c.validateCurrentRunID(ctx, currentExecution)
if runIDCheckResult != nil {
return *runIDCheckResult
}
}

concreteExecResp, concreteExecErr := c.pr.IsWorkflowExecutionExists(context.TODO(), &persistence.IsWorkflowExecutionExistsRequest{
concreteExecResp, concreteExecErr := c.pr.IsWorkflowExecutionExists(ctx, &persistence.IsWorkflowExecutionExistsRequest{
DomainID: currentExecution.DomainID,
WorkflowID: currentExecution.WorkflowID,
RunID: currentExecution.CurrentRunID,
Expand All @@ -83,7 +87,7 @@ func (c *concreteExecutionExists) Check(
}
if !concreteExecResp.Exists {
//verify if the current execution exists
_, checkResult := c.validateCurrentRunID(currentExecution)
_, checkResult := c.validateCurrentRunID(ctx, currentExecution)
if checkResult != nil {
return *checkResult
}
Expand All @@ -102,14 +106,18 @@ func (c *concreteExecutionExists) Check(
}

func (c *concreteExecutionExists) Fix(
ctx context.Context,
execution interface{},
) FixResult {
if fixResult := validateFixContext(ctx, c.Name()); fixResult != nil {
return *fixResult
}

currentExecution, _ := execution.(*entity.CurrentExecution)
var runIDCheckResult *CheckResult
if len(currentExecution.CurrentRunID) == 0 {
// this is to set the current run ID prior to the check and fix operations
currentExecution, runIDCheckResult = c.validateCurrentRunID(currentExecution)
currentExecution, runIDCheckResult = c.validateCurrentRunID(ctx, currentExecution)
if runIDCheckResult != nil {
return FixResult{
FixResultType: FixResultTypeSkipped,
Expand All @@ -118,17 +126,18 @@ func (c *concreteExecutionExists) Fix(
}
}
}
fixResult, checkResult := checkBeforeFix(c, currentExecution)
fixResult, checkResult := checkBeforeFix(ctx, c, currentExecution)
if fixResult != nil {
return *fixResult
}
if err := c.pr.DeleteCurrentWorkflowExecution(context.TODO(), &persistence.DeleteCurrentWorkflowExecutionRequest{
if err := c.pr.DeleteCurrentWorkflowExecution(ctx, &persistence.DeleteCurrentWorkflowExecutionRequest{
DomainID: currentExecution.DomainID,
WorkflowID: currentExecution.WorkflowID,
RunID: currentExecution.CurrentRunID,
}); err != nil {
return FixResult{
FixResultType: FixResultTypeFailed,
InvariantName: c.Name(),
Info: "failed to delete current workflow execution",
InfoDetails: err.Error(),
}
Expand All @@ -145,10 +154,11 @@ func (c *concreteExecutionExists) Name() Name {
}

func (c *concreteExecutionExists) validateCurrentRunID(
ctx context.Context,
currentExecution *entity.CurrentExecution,
) (*entity.CurrentExecution, *CheckResult) {

resp, err := c.pr.GetCurrentExecution(context.TODO(), &persistence.GetCurrentExecutionRequest{
resp, err := c.pr.GetCurrentExecution(ctx, &persistence.GetCurrentExecutionRequest{
DomainID: currentExecution.DomainID,
WorkflowID: currentExecution.WorkflowID,
})
Expand Down
Loading

0 comments on commit 9afea11

Please sign in to comment.