Skip to content

Commit

Permalink
Support cancelling a run from a JobProgressRef (#1663)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcastorina authored Aug 25, 2023
1 parent 33eed42 commit 5eb776c
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 1 deletion.
19 changes: 19 additions & 0 deletions pkg/sources/job_progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type JobProgressHook interface {
}

// JobProgressRef is a wrapper of a JobProgress for read-only access to its state.
// If the job supports it, the reference can also be used to cancel running via
// CancelRun.
type JobProgressRef struct {
JobID int64
SourceID int64
Expand All @@ -66,6 +68,16 @@ func (r *JobProgressRef) Done() <-chan struct{} {
return r.jobProgress.Done()
}

// CancelRun requests that the job this is referencing is cancelled and stops
// running. This method will have no effect if the job does not allow
// cancellation.
func (r *JobProgressRef) CancelRun() {
if r.jobProgress == nil || r.jobProgress.jobCancel == nil {
return
}
r.jobProgress.jobCancel()
}

// Fatal is a wrapper around error to differentiate non-fatal errors from fatal
// ones. A fatal error is typically from a finished context or any error
// returned from a source's Init, Chunks, Enumerate, or ChunkUnit methods.
Expand Down Expand Up @@ -95,6 +107,8 @@ type JobProgress struct {
// Tracks whether the job is finished or not.
ctx context.Context
cancel context.CancelFunc
// Requests to cancel the job.
jobCancel context.CancelFunc
// Metrics.
metrics JobProgressMetrics
metricsLock sync.Mutex
Expand Down Expand Up @@ -135,6 +149,11 @@ func WithHooks(hooks ...JobProgressHook) func(*JobProgress) {
return func(jp *JobProgress) { jp.hooks = append(jp.hooks, hooks...) }
}

// WithCancel allows cancelling the job by the JobProgressRef.
func WithCancel(cancel context.CancelFunc) func(*JobProgress) {
return func(jp *JobProgress) { jp.jobCancel = cancel }
}

// NewJobProgress creates a new job report for the given source and job ID.
func NewJobProgress(jobID, sourceID int64, sourceName string, opts ...func(*JobProgress)) *JobProgress {
ctx, cancel := context.WithCancel(context.Background())
Expand Down
4 changes: 3 additions & 1 deletion pkg/sources/source_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,15 @@ func (s *SourceManager) asyncRun(ctx context.Context, handle handle) (JobProgres
return JobProgressRef{SourceID: int64(handle), SourceName: sourceName}, err
}
// Create a JobProgress object for tracking progress.
progress := NewJobProgress(jobID, int64(handle), sourceName, WithHooks(s.hooks...))
ctx, cancel := context.WithCancel(ctx)
progress := NewJobProgress(jobID, int64(handle), sourceName, WithHooks(s.hooks...), WithCancel(cancel))
s.pool.Go(func() error {
ctx := context.WithValues(ctx,
"job_id", jobID,
"source_manager_worker_id", common.RandomID(5),
)
defer common.Recover(ctx)
defer cancel()
return s.run(ctx, handle, jobID, progress)
})
return progress.Ref(), nil
Expand Down
30 changes: 30 additions & 0 deletions pkg/sources/source_manager_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sources

import (
"errors"
"fmt"
"testing"

Expand Down Expand Up @@ -287,3 +288,32 @@ func TestSourceManagerJobAndSourceIDs(t *testing.T) {
assert.Equal(t, int64(9001), ref.JobID)
assert.Equal(t, "dummy", ref.SourceName)
}

// Chunk method that has a custom callback for the Chunks method.
type callbackChunker struct {
cb func(context.Context, chan *Chunk) error
}

func (c callbackChunker) Chunks(ctx context.Context, ch chan *Chunk) error { return c.cb(ctx, ch) }
func (c callbackChunker) Enumerate(context.Context, UnitReporter) error { return nil }
func (c callbackChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return nil }

func TestSourceManagerCancelRun(t *testing.T) {
mgr := NewManager(WithBufferedOutput(8))
var returnedErr error
handle, err := enrollDummy(mgr, callbackChunker{func(ctx context.Context, _ chan *Chunk) error {
// The context passed to Chunks should get cancelled when ref.CancelRun() is called.
<-ctx.Done()
returnedErr = fmt.Errorf("oh no: %w", ctx.Err())
return returnedErr
}})
assert.NoError(t, err)

ref, err := mgr.ScheduleRun(context.Background(), handle)
assert.NoError(t, err)

ref.CancelRun()
<-ref.Done()
assert.Error(t, ref.Snapshot().FatalError())
assert.True(t, errors.Is(ref.Snapshot().FatalError(), returnedErr))
}

0 comments on commit 5eb776c

Please sign in to comment.