Skip to content

Commit

Permalink
StartWorkflowExecution: validate RequestID before calling history (ub…
Browse files Browse the repository at this point in the history
…er#5359)

What changed?
Validating RequestID (uuid) at handler level.

Why?
When calling StartWorkflowExecution and passing wrongly formatted UUID, a generic error will be returned from persistence layer (cassandra for example). This error is not treated as non-retryable, so Cadence will try to insert wrong data multiple times. On the client side, only request-timeout will be returned which reveals no details about the nature for this failure.

This change will validate UUID on handler side and no calls to history/persistence will be made. Additionally, user will get information on what data is missing or malformed.

How did you test it?
Unit test updated to include malformed UUID check
  • Loading branch information
mantas-sidlauskas committed Aug 4, 2023
1 parent 1202a2e commit 4117f96
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 35 deletions.
65 changes: 32 additions & 33 deletions service/frontend/workflowHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
"sync/atomic"
"time"

"github.com/pborman/uuid"
"github.com/google/uuid"
"go.uber.org/yarpc"
"go.uber.org/yarpc/yarpcerrors"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -131,7 +131,6 @@ var (
errQueryTypeNotSet = &types.BadRequestError{Message: "QueryType is not set on request."}
errRequestNotSet = &types.BadRequestError{Message: "Request is nil."}
errNoPermission = &types.BadRequestError{Message: "No permission to do this operation."}
errRequestIDNotSet = &types.BadRequestError{Message: "RequestId is not set on request."}
errWorkflowTypeNotSet = &types.BadRequestError{Message: "WorkflowType is not set on request."}
errInvalidRetention = &types.BadRequestError{Message: "RetentionDays is invalid."}
errInvalidExecutionStartToCloseTimeoutSeconds = &types.BadRequestError{Message: "A valid ExecutionStartToCloseTimeoutSeconds is not set on request."}
Expand Down Expand Up @@ -615,7 +614,7 @@ func (wh *WorkflowHandler) PollForActivityTask(
); err != nil {
return &types.PollForActivityTaskResponse{}, nil
}
pollerID := uuid.New()
pollerID := uuid.New().String()
op := func() error {
resp, err = wh.GetMatchingClient().PollForActivityTask(ctx, &types.MatchingPollForActivityTaskRequest{
DomainUUID: domainID,
Expand Down Expand Up @@ -745,7 +744,7 @@ func (wh *WorkflowHandler) PollForDecisionTask(
return &types.PollForDecisionTaskResponse{}, nil
}

pollerID := uuid.New()
pollerID := uuid.New().String()
var matchingResp *types.MatchingPollForDecisionTaskResponse
op := func() error {
matchingResp, err = wh.GetMatchingClient().PollForDecisionTask(ctx, &types.MatchingPollForDecisionTaskRequest{
Expand Down Expand Up @@ -2082,18 +2081,14 @@ func (wh *WorkflowHandler) StartWorkflowExecution(
scope, sw := wh.startRequestProfileWithDomain(ctx, metrics.FrontendStartWorkflowExecutionScope, startRequest)
defer sw.Stop()

if wh.isShuttingDown() {
return nil, errShuttingDown
}

if err := wh.versionChecker.ClientSupported(ctx, wh.config.EnableClientVersionCheck()); err != nil {
return nil, wh.error(err, scope)
}

if startRequest == nil {
return nil, wh.error(errRequestNotSet, scope)
}

if wh.isShuttingDown() {
return nil, errShuttingDown
}

domainName := startRequest.GetDomain()
wfExecution := &types.WorkflowExecution{
WorkflowID: startRequest.GetWorkflowID(),
Expand All @@ -2104,6 +2099,21 @@ func (wh *WorkflowHandler) StartWorkflowExecution(
return nil, wh.error(errDomainNotSet, scope, tags...)
}

if startRequest.GetWorkflowID() == "" {
return nil, wh.error(errWorkflowIDNotSet, scope, tags...)
}

if _, err := uuid.Parse(startRequest.RequestID); err != nil {
return nil, wh.error(&types.BadRequestError{Message: fmt.Sprintf("requestId %q is not a valid UUID", startRequest.RequestID)}, scope, tags...)
}
if startRequest.WorkflowType == nil || startRequest.WorkflowType.GetName() == "" {
return nil, wh.error(errWorkflowTypeNotSet, scope, tags...)
}

if err := wh.versionChecker.ClientSupported(ctx, wh.config.EnableClientVersionCheck()); err != nil {
return nil, wh.error(err, scope)
}

if ok := wh.allow(ratelimitTypeUser, startRequest); !ok {
return nil, wh.error(createServiceBusyError(), scope, tags...)
}
Expand All @@ -2121,10 +2131,6 @@ func (wh *WorkflowHandler) StartWorkflowExecution(
return nil, wh.error(errDomainTooLong, scope, tags...)
}

if startRequest.GetWorkflowID() == "" {
return nil, wh.error(errWorkflowIDNotSet, scope, tags...)
}

if !common.ValidIDLength(
startRequest.GetWorkflowID(),
scope,
Expand All @@ -2141,20 +2147,10 @@ func (wh *WorkflowHandler) StartWorkflowExecution(
return nil, wh.error(err, scope, tags...)
}

if startRequest.GetCronSchedule() != "" {
if _, err := backoff.ValidateSchedule(startRequest.GetCronSchedule()); err != nil {
return nil, wh.error(err, scope, tags...)
}
}

wh.GetLogger().Debug(
"Received StartWorkflowExecution. WorkflowID",
tag.WorkflowID(startRequest.GetWorkflowID()))

if startRequest.WorkflowType == nil || startRequest.WorkflowType.GetName() == "" {
return nil, wh.error(errWorkflowTypeNotSet, scope, tags...)
}

if !common.ValidIDLength(
startRequest.WorkflowType.GetName(),
scope,
Expand Down Expand Up @@ -2189,6 +2185,11 @@ func (wh *WorkflowHandler) StartWorkflowExecution(

jitter := startRequest.GetJitterStartSeconds()
cron := startRequest.GetCronSchedule()
if cron != "" {
if _, err := backoff.ValidateSchedule(startRequest.GetCronSchedule()); err != nil {
return nil, wh.error(err, scope, tags...)
}
}
if jitter > 0 && cron != "" {
// Calculate the cron duration and ensure that jitter is not greater than the cron duration,
// because that would be confusing to users.
Expand All @@ -2205,10 +2206,6 @@ func (wh *WorkflowHandler) StartWorkflowExecution(
}
}

if startRequest.GetRequestID() == "" {
return nil, wh.error(errRequestIDNotSet, scope, tags...)
}

if !common.ValidIDLength(
startRequest.GetRequestID(),
scope,
Expand Down Expand Up @@ -4264,8 +4261,10 @@ func validateExecution(w *types.WorkflowExecution) error {
if w.GetWorkflowID() == "" {
return errWorkflowIDNotSet
}
if w.GetRunID() != "" && uuid.Parse(w.GetRunID()) == nil {
return errInvalidRunID
if w.GetRunID() != "" {
if _, err := uuid.Parse(w.GetRunID()); err != nil {
return errInvalidRunID
}
}
return nil
}
Expand Down Expand Up @@ -4715,7 +4714,7 @@ func (wh *WorkflowHandler) normalizeVersionedErrors(ctx context.Context, err err
func constructRestartWorkflowRequest(w *types.WorkflowExecutionStartedEventAttributes, domain string, identity string, workflowID string) *types.StartWorkflowExecutionRequest {

startRequest := &types.StartWorkflowExecutionRequest{
RequestID: uuid.New(),
RequestID: uuid.New().String(),
Domain: domain,
WorkflowID: workflowID,
WorkflowType: &types.WorkflowType{
Expand Down
9 changes: 7 additions & 2 deletions service/frontend/workflowHandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,12 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_RequestIdNotSet
}
_, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest)
s.Error(err)
s.Equal(errRequestIDNotSet, err)
s.Equal(&types.BadRequestError{Message: "requestId \"\" is not a valid UUID"}, err)
startWorkflowExecutionRequest.RequestID = "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
_, err = wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest)
s.Error(err)
s.Equal(&types.BadRequestError{Message: "requestId \"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx\" is not a valid UUID"}, err)

}

func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_BadDelayStartSeconds() {
Expand Down Expand Up @@ -1389,7 +1394,7 @@ func (s *workflowHandlerSuite) TestRestartWorkflowExecution__Success() {
},
Identity: "",
})
s.Equal(resp.GetRunID(), testRunID)
s.Equal(testRunID, resp.GetRunID())
s.NoError(err)
}

Expand Down

0 comments on commit 4117f96

Please sign in to comment.