Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce persistence context timeout in application layer: Part 2 #3590

Merged
merged 3 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions common/ndc/history_resender.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ func (n *HistoryResenderImpl) SendSingleWorkflowHistory(
// Case 1: the workflow pass the retention period
// Case 2: the workflow is corrupted
if skipTask := n.fixCurrentExecution(
ctx,
domainID,
workflowID,
runID,
Expand Down Expand Up @@ -305,6 +306,7 @@ func (n *HistoryResenderImpl) getHistory(
}

func (n *HistoryResenderImpl) fixCurrentExecution(
ctx context.Context,
domainID string,
workflowID string,
runID string,
Expand All @@ -320,7 +322,7 @@ func (n *HistoryResenderImpl) fixCurrentExecution(
State: persistence.WorkflowStateRunning,
},
}
res := n.currentExecutionCheck.Check(execution)
res := n.currentExecutionCheck.Check(ctx, execution)
switch res.CheckResultType {
case invariant.CheckResultTypeCorrupted:
n.logger.Error(
Expand All @@ -329,7 +331,7 @@ func (n *HistoryResenderImpl) fixCurrentExecution(
tag.WorkflowID(workflowID),
tag.WorkflowRunID(runID),
)
n.currentExecutionCheck.Fix(execution)
n.currentExecutionCheck.Fix(ctx, execution)
return false
case invariant.CheckResultTypeFailed:
return false
Expand Down
10 changes: 5 additions & 5 deletions common/ndc/history_resender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,17 @@ func (s *historyResenderSuite) TestCurrentExecutionCheck() {
State: persistence.WorkflowStateRunning,
},
}
invariantMock.EXPECT().Check(execution1).Return(invariant.CheckResult{
invariantMock.EXPECT().Check(gomock.Any(), execution1).Return(invariant.CheckResult{
CheckResultType: invariant.CheckResultTypeCorrupted,
}).Times(1)
invariantMock.EXPECT().Check(execution2).Return(invariant.CheckResult{
invariantMock.EXPECT().Check(gomock.Any(), execution2).Return(invariant.CheckResult{
CheckResultType: invariant.CheckResultTypeHealthy,
}).Times(1)
invariantMock.EXPECT().Fix(gomock.Any()).Return(invariant.FixResult{}).Times(1)
invariantMock.EXPECT().Fix(gomock.Any(), gomock.Any()).Return(invariant.FixResult{}).Times(1)

skipTask := s.rereplicator.fixCurrentExecution(domainID, workflowID1, runID)
skipTask := s.rereplicator.fixCurrentExecution(context.Background(), domainID, workflowID1, runID)
s.False(skipTask)
skipTask = s.rereplicator.fixCurrentExecution(domainID, workflowID2, runID)
skipTask = s.rereplicator.fixCurrentExecution(context.Background(), domainID, workflowID2, runID)
s.True(skipTask)
}

Expand Down
7 changes: 5 additions & 2 deletions common/pagination/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@

package pagination

import "errors"
import (
"context"
"errors"
)

// ErrIteratorFinished indicates that Next was called on a finished iterator
var ErrIteratorFinished = errors.New("iterator has reached end")
Expand Down Expand Up @@ -52,7 +55,7 @@ type (
ShouldFlushFn func(Page) bool
// FetchFn fetches Page from PageToken.
// Once a page with nil NextToken is returned no more pages will be fetched.
FetchFn func(PageToken) (Page, error)
FetchFn func(context.Context, PageToken) (Page, error)
)

type (
Expand Down
11 changes: 8 additions & 3 deletions common/pagination/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@

package pagination

import "context"

type (
iterator struct {
ctx context.Context

page Page
entityIndex int

Expand All @@ -36,6 +40,7 @@ type (

// NewIterator constructs a new Iterator
func NewIterator(
ctx context.Context,
startingPageToken PageToken,
fetchFn FetchFn,
) Iterator {
Expand All @@ -56,9 +61,9 @@ func NewIterator(
// Returning nil, nil is valid if that is what the provided fetch function provided.
func (i *iterator) Next() (Entity, error) {
entity := i.nextEntity
error := i.nextError
err := i.nextError
i.advance(false)
return entity, error
return entity, err
}

// HasNext returns true if there is a next element. There is considered to be a next element
Expand Down Expand Up @@ -86,7 +91,7 @@ func (i *iterator) advanceToNonEmptyPage(firstPage bool) error {
if i.page.NextToken == nil && !firstPage {
return ErrIteratorFinished
}
nextPage, err := i.fetchFn(i.page.NextToken)
nextPage, err := i.fetchFn(i.ctx, i.page.NextToken)
if err != nil {
return err
}
Expand Down
13 changes: 7 additions & 6 deletions common/pagination/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package pagination

import (
"context"
"errors"
"testing"

Expand Down Expand Up @@ -52,7 +53,7 @@ func (s *IteratorSuite) SetupTest() {
}

func (s *IteratorSuite) TestInitializedToEmpty() {
fetchFn := func(token PageToken) (Page, error) {
fetchFn := func(_ context.Context, token PageToken) (Page, error) {
if token.(int) == 2 {
return Page{
CurrentToken: token,
Expand All @@ -66,14 +67,14 @@ func (s *IteratorSuite) TestInitializedToEmpty() {
Entities: fetchMap[token],
}, nil
}
itr := NewIterator(0, fetchFn)
itr := NewIterator(context.Background(), 0, fetchFn)
s.False(itr.HasNext())
_, err := itr.Next()
s.Equal(ErrIteratorFinished, err)
}

func (s *IteratorSuite) TestNonEmptyNoErrors() {
fetchFn := func(token PageToken) (Page, error) {
fetchFn := func(_ context.Context, token PageToken) (Page, error) {
var nextPageToken interface{} = token.(int) + 1
if nextPageToken.(int) == 5 {
nextPageToken = nil
Expand All @@ -84,7 +85,7 @@ func (s *IteratorSuite) TestNonEmptyNoErrors() {
Entities: fetchMap[token],
}, nil
}
itr := NewIterator(0, fetchFn)
itr := NewIterator(context.Background(), 0, fetchFn)
expectedResults := []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}
i := 0
for itr.HasNext() {
Expand All @@ -99,7 +100,7 @@ func (s *IteratorSuite) TestNonEmptyNoErrors() {
}

func (s *IteratorSuite) TestNonEmptyWithErrors() {
fetchFn := func(token PageToken) (Page, error) {
fetchFn := func(_ context.Context, token PageToken) (Page, error) {
if token.(int) == 4 {
return Page{}, errors.New("got error")
}
Expand All @@ -109,7 +110,7 @@ func (s *IteratorSuite) TestNonEmptyWithErrors() {
Entities: fetchMap[token],
}, nil
}
itr := NewIterator(0, fetchFn)
itr := NewIterator(context.Background(), 0, fetchFn)
expectedResults := []string{"one", "two", "three", "four", "five", "six", "seven"}
i := 0
for itr.HasNext() {
Expand Down
5 changes: 3 additions & 2 deletions common/pagination/writerIterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package pagination

import (
"bytes"
"context"
"encoding/json"
"fmt"
"testing"
Expand Down Expand Up @@ -86,7 +87,7 @@ func (s *WriterIteratorSuite) TestWriterIterator() {
s.Equal(expectedKey, flushedKeys[i].(string))
}

fetchFn := func(token PageToken) (Page, error) {
fetchFn := func(_ context.Context, token PageToken) (Page, error) {
key := flushedKeys[token.(int)]
data := store[key.(string)]
dataBlobs := bytes.Split(data, separator)
Expand All @@ -112,7 +113,7 @@ func (s *WriterIteratorSuite) TestWriterIterator() {
}, nil
}

itr := NewIterator(0, fetchFn)
itr := NewIterator(context.Background(), 0, fetchFn)
itrCount := 0
for itr.HasNext() {
val, err := itr.Next()
Expand Down
20 changes: 14 additions & 6 deletions common/reconciliation/fetcher/concrete.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,20 @@ import (
)

// ConcreteExecutionIterator is used to retrieve Concrete executions.
func ConcreteExecutionIterator(retryer persistence.Retryer, pageSize int) pagination.Iterator {
return pagination.NewIterator(nil, getConcreteExecutions(retryer, pageSize, codec.NewThriftRWEncoder()))
func ConcreteExecutionIterator(
ctx context.Context,
retryer persistence.Retryer,
pageSize int,
) pagination.Iterator {
return pagination.NewIterator(ctx, nil, getConcreteExecutions(retryer, pageSize, codec.NewThriftRWEncoder()))
}

// ConcreteExecution returns a single ConcreteExecution from persistence
func ConcreteExecution(retryer persistence.Retryer, request ExecutionRequest) (entity.Entity, error) {
func ConcreteExecution(
ctx context.Context,
retryer persistence.Retryer,
request ExecutionRequest,
) (entity.Entity, error) {

req := persistence.GetWorkflowExecutionRequest{
DomainID: request.DomainID,
Expand All @@ -48,7 +56,7 @@ func ConcreteExecution(retryer persistence.Retryer, request ExecutionRequest) (e
RunId: common.StringPtr(request.RunID),
},
}
e, err := retryer.GetWorkflowExecution(context.TODO(), &req)
e, err := retryer.GetWorkflowExecution(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -74,14 +82,14 @@ func getConcreteExecutions(
pageSize int,
encoder *codec.ThriftRWEncoder,
) pagination.FetchFn {
return func(token pagination.PageToken) (pagination.Page, error) {
return func(ctx context.Context, token pagination.PageToken) (pagination.Page, error) {
req := &persistence.ListConcreteExecutionsRequest{
PageSize: pageSize,
}
if token != nil {
req.PageToken = token.([]byte)
}
resp, err := pr.ListConcreteExecutions(context.TODO(), req)
resp, err := pr.ListConcreteExecutions(ctx, req)
if err != nil {
return pagination.Page{}, err
}
Expand Down
20 changes: 14 additions & 6 deletions common/reconciliation/fetcher/current.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,25 @@ import (
)

// CurrentExecutionIterator is used to retrieve Concrete executions.
func CurrentExecutionIterator(retryer persistence.Retryer, pageSize int) pagination.Iterator {
return pagination.NewIterator(nil, getCurrentExecution(retryer, pageSize))
func CurrentExecutionIterator(
ctx context.Context,
retryer persistence.Retryer,
pageSize int,
) pagination.Iterator {
return pagination.NewIterator(ctx, nil, getCurrentExecution(retryer, pageSize))
}

// CurrentExecution returns a single execution
func CurrentExecution(retryer persistence.Retryer, request ExecutionRequest) (entity.Entity, error) {
func CurrentExecution(
ctx context.Context,
retryer persistence.Retryer,
request ExecutionRequest,
) (entity.Entity, error) {
req := persistence.GetCurrentExecutionRequest{
DomainID: request.DomainID,
WorkflowID: request.WorkflowID,
}
e, err := retryer.GetCurrentExecution(context.TODO(), &req)
e, err := retryer.GetCurrentExecution(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -62,14 +70,14 @@ func getCurrentExecution(
pr persistence.Retryer,
pageSize int,
) pagination.FetchFn {
return func(token pagination.PageToken) (pagination.Page, error) {
return func(ctx context.Context, token pagination.PageToken) (pagination.Page, error) {
req := &persistence.ListCurrentExecutionsRequest{
PageSize: pageSize,
}
if token != nil {
req.PageToken = token.([]byte)
}
resp, err := pr.ListCurrentExecutions(context.TODO(), req)
resp, err := pr.ListCurrentExecutions(ctx, req)
if err != nil {
return pagination.Page{}, err
}
Expand Down
24 changes: 17 additions & 7 deletions common/reconciliation/invariant/concreteExecutionExists.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ func NewConcreteExecutionExists(
}

func (c *concreteExecutionExists) Check(
ctx context.Context,
execution interface{},
) CheckResult {
if checkResult := validateCheckContext(ctx, c.Name()); checkResult != nil {
return *checkResult
}

currentExecution, ok := execution.(*entity.CurrentExecution)
if !ok {
Expand All @@ -62,13 +66,13 @@ func (c *concreteExecutionExists) Check(
if len(currentExecution.CurrentRunID) == 0 {
// set the current run id
var runIDCheckResult *CheckResult
currentExecution, runIDCheckResult = c.validateCurrentRunID(currentExecution)
currentExecution, runIDCheckResult = c.validateCurrentRunID(ctx, currentExecution)
if runIDCheckResult != nil {
return *runIDCheckResult
}
}

concreteExecResp, concreteExecErr := c.pr.IsWorkflowExecutionExists(context.TODO(), &persistence.IsWorkflowExecutionExistsRequest{
concreteExecResp, concreteExecErr := c.pr.IsWorkflowExecutionExists(ctx, &persistence.IsWorkflowExecutionExistsRequest{
DomainID: currentExecution.DomainID,
WorkflowID: currentExecution.WorkflowID,
RunID: currentExecution.CurrentRunID,
Expand All @@ -83,7 +87,7 @@ func (c *concreteExecutionExists) Check(
}
if !concreteExecResp.Exists {
//verify if the current execution exists
_, checkResult := c.validateCurrentRunID(currentExecution)
_, checkResult := c.validateCurrentRunID(ctx, currentExecution)
if checkResult != nil {
return *checkResult
}
Expand All @@ -102,14 +106,18 @@ func (c *concreteExecutionExists) Check(
}

func (c *concreteExecutionExists) Fix(
ctx context.Context,
execution interface{},
) FixResult {
if fixResult := validateFixContext(ctx, c.Name()); fixResult != nil {
return *fixResult
}

currentExecution, _ := execution.(*entity.CurrentExecution)
var runIDCheckResult *CheckResult
if len(currentExecution.CurrentRunID) == 0 {
// this is to set the current run ID prior to the check and fix operations
currentExecution, runIDCheckResult = c.validateCurrentRunID(currentExecution)
currentExecution, runIDCheckResult = c.validateCurrentRunID(ctx, currentExecution)
if runIDCheckResult != nil {
return FixResult{
FixResultType: FixResultTypeSkipped,
Expand All @@ -118,17 +126,18 @@ func (c *concreteExecutionExists) Fix(
}
}
}
fixResult, checkResult := checkBeforeFix(c, currentExecution)
fixResult, checkResult := checkBeforeFix(ctx, c, currentExecution)
if fixResult != nil {
return *fixResult
}
if err := c.pr.DeleteCurrentWorkflowExecution(context.TODO(), &persistence.DeleteCurrentWorkflowExecutionRequest{
if err := c.pr.DeleteCurrentWorkflowExecution(ctx, &persistence.DeleteCurrentWorkflowExecutionRequest{
DomainID: currentExecution.DomainID,
WorkflowID: currentExecution.WorkflowID,
RunID: currentExecution.CurrentRunID,
}); err != nil {
return FixResult{
FixResultType: FixResultTypeFailed,
InvariantName: c.Name(),
Info: "failed to delete current workflow execution",
InfoDetails: err.Error(),
}
Expand All @@ -145,10 +154,11 @@ func (c *concreteExecutionExists) Name() Name {
}

func (c *concreteExecutionExists) validateCurrentRunID(
ctx context.Context,
currentExecution *entity.CurrentExecution,
) (*entity.CurrentExecution, *CheckResult) {

resp, err := c.pr.GetCurrentExecution(context.TODO(), &persistence.GetCurrentExecutionRequest{
resp, err := c.pr.GetCurrentExecution(ctx, &persistence.GetCurrentExecutionRequest{
DomainID: currentExecution.DomainID,
WorkflowID: currentExecution.WorkflowID,
})
Expand Down
Loading