Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ fmt.Printf("Duration: %v\n", task.Duration())

The library supports several common concurrency patterns out of the box:

- **Worker Pools**](#worker-pools)** - Controlled concurrency with `Consume` and `InvokeAll`
- **Worker Pools** - Controlled concurrency with `Consume` and `InvokeAll`
- **Fork/Join** - Parallel task execution with result aggregation
- **Throttling** - Rate limiting with `Consume` and custom concurrency
- **Repeating** - Periodic execution with `Repeat`
- **Task Chaining** - Sequential execution with `After` for processing pipelines



Expand Down Expand Up @@ -186,6 +187,36 @@ time.Sleep(5 * time.Minute)
heartbeat.Cancel()
```

## Task Chaining

Task chaining enables sequential execution of dependent operations where the output of one task becomes the input of the next. This pattern is essential for creating processing pipelines, implementing workflows, or building reactive systems where operations must happen in a specific order. The `After` function provides a clean, functional approach to task composition with automatic execution and result propagation.

```go
// Create a processing pipeline
task1 := async.NewTask(func(ctx context.Context) (string, error) {
// Fetch raw data
return "raw data", nil
})

task2 := async.After(task1, func(ctx context.Context, data string) (string, error) {
// Process the raw data
return "processed: " + data, nil
})

task3 := async.After(task2, func(ctx context.Context, processed string) (string, error) {
// Final transformation
return "final: " + processed, nil
})

// Start the chain by running the first task
task1.Run(context.Background())

// Get the final result
result, err := task3.Outcome()
// result will be "final: processed: raw data"
```


## Benchmarks

The benchmarks demonstrate the library's excellent performance characteristics across different usage patterns.
Expand Down
14 changes: 5 additions & 9 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,11 @@ import (

/*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkTask/Consume-24 1125 974459 ns/op 1122649 B/op 16025 allocs/op
BenchmarkTask/Invoke-24 1000000 1105 ns/op 528 B/op 7 allocs/op
BenchmarkTask/InvokeAll-24 1298 938814 ns/op 1114347 B/op 16023 allocs/op

BenchmarkTask/Consume-24 4054 309833 ns/op 145127 B/op 2014 allocs/op
BenchmarkTask/Invoke-24 2361956 507.6 ns/op 128 B/op 2 allocs/op
BenchmarkTask/InvokeAll-24 4262 303242 ns/op 161449 B/op 2015 allocs/op
BenchmarkTask/Completed-24 89886966 13.36 ns/op 32 B/op 1 allocs/op
BenchmarkTask/Errored-24 89026714 13.50 ns/op 32 B/op 1 allocs/op
BenchmarkTask/Consume-24 3796 318294 ns/op 145122 B/op 2014 allocs/op
BenchmarkTask/Invoke-24 2116862 570.9 ns/op 128 B/op 2 allocs/op
BenchmarkTask/InvokeAll-24 3760 336794 ns/op 161456 B/op 2015 allocs/op
BenchmarkTask/Completed-24 86476514 13.87 ns/op 32 B/op 1 allocs/op
BenchmarkTask/Errored-24 86474020 14.00 ns/op 32 B/op 1 allocs/op
*/
func BenchmarkTask(b *testing.B) {
b.Run("Consume", func(b *testing.B) {
Expand Down
65 changes: 58 additions & 7 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ type outcome[T any] struct {
err error // The error
}

// Task represents a unit of work to be done
// task represents a unit of work to be done
type task[T any] struct {
state int32 // This indicates whether the task is started or not
duration int64 // The duration of the task, in nanoseconds
wg sync.WaitGroup // Used to wait for completion instead of channel
action Work[T] // The work to do
outcome outcome[T] // This is used to store the result
state int32 // This indicates whether the task is started or not
duration int64 // The duration of the task, in nanoseconds
wg sync.WaitGroup // Used to wait for completion instead of channel
action Work[T] // The work to do
outcome outcome[T] // This is used to store the result
chain atomic.Pointer[chain] // Continuation functions
}

// Awaiter is an interface that can be used to wait for a task to complete.
Expand All @@ -70,7 +71,7 @@ func NewTask[T any](action Work[T]) Task[T] {
t := &task[T]{
action: action,
}
t.wg.Add(1) // Will be Done() when task completes
t.wg.Add(1)
return t
}

Expand Down Expand Up @@ -164,8 +165,16 @@ func (t *task[T]) run(ctx context.Context) {
}
}()

// Execute the task
r, e := t.action(ctx)
t.outcome = outcome[T]{result: r, err: e}

// Run next tasks with the same context
if cont := t.chain.Load(); cont != nil {
for _, next := range *cont {
next(ctx)
}
}
}()

atomic.StoreInt64(&t.duration, now().UnixNano()-startedAt)
Expand Down Expand Up @@ -244,3 +253,45 @@ func Completed[T any](result T) Task[T] {
func Failed[T any](err error) Task[T] {
return &completedTask[T]{err: err}
}

// -------------------------------- Continuation Task --------------------------------

type chain = []func(context.Context)

// After creates a continuation task that automatically runs when the predecessor completes
func After[T, U any](predecessor Task[T], work func(context.Context, T) (U, error)) Task[U] {
prev, ok := predecessor.(*task[T])
if !ok {
return Failed[U](fmt.Errorf("predecessor does not support chaining"))
}

// Since this function is only called after predecessor completes,
// we can directly access its outcome without waiting
next := NewTask(func(ctx context.Context) (U, error) {
if prev.outcome.err != nil {
var zero U
return zero, prev.outcome.err
}

return work(ctx, prev.outcome.result) //nolint:scopelint
}).(*task[U])

// Add continuation function using atomic operations
for {
curr := prev.chain.Load()
cont := withNext(curr, next.run)
if prev.chain.CompareAndSwap(curr, &cont) {
break
}
}
return next
}

// withNext adds a new continuation function to the list of continuations
func withNext(current *chain, next func(context.Context)) chain {
if current == nil {
return chain{next}
}

return append((*current), next)
}
145 changes: 145 additions & 0 deletions task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,148 @@ func TestWait(t *testing.T) {
assert.Error(t, task.Wait())
assert.Error(t, task.Wait())
}

func TestAfterBasicChaining(t *testing.T) {
result1 := "first task"
result2 := "second task"

// Create first task
task1 := NewTask(func(ctx context.Context) (string, error) {
time.Sleep(time.Millisecond * 10)
return result1, nil
})

// Chain second task after first
task2 := After(task1, func(ctx context.Context, result1 string) (any, error) {
time.Sleep(time.Millisecond * 10)
return result2, nil
})

// Start the first task (this will trigger the chain)
task1.Run(context.Background())

// Verify both tasks completed
firstResult, err1 := task1.Outcome()
assert.NoError(t, err1)
assert.Equal(t, result1, firstResult)

secondResult, err2 := task2.Outcome()
assert.NoError(t, err2)
assert.Equal(t, result2, secondResult)

// Verify both tasks have durations
assert.True(t, task1.Duration() > 0)
assert.True(t, task2.Duration() > 0)
}

func TestAfterWithError(t *testing.T) {
expectedError := errors.New("first task failed")

// Create first task that fails
task1 := NewTask(func(ctx context.Context) (string, error) {
return "", expectedError
})

// Chain second task after first (should still run)
task2 := After(task1, func(ctx context.Context, result1 string) (any, error) {
return "second task succeeded", nil
})

// Start the first task
task1.Run(context.Background())

// Verify first task failed
_, err1 := task1.Outcome()
assert.Error(t, err1)
assert.Equal(t, expectedError, err1)

// Verify second task is showing the same error
_, err2 := task2.Outcome()
assert.Error(t, err2)
assert.Equal(t, expectedError, err2)
}

func TestAfterMultipleChaining(t *testing.T) {
// Create a chain of tasks
task1 := NewTask(func(ctx context.Context) (string, error) {
return "task1", nil
})

task2 := After(task1, func(ctx context.Context, result1 string) (any, error) {
return "task2", nil
})

task3 := After(task2, func(ctx context.Context, result2 any) (any, error) {
return "task3", nil
})

// Start the first task (this will trigger the entire chain)
task1.Run(context.Background())

// Verify all tasks completed
result1, err1 := task1.Outcome()
assert.NoError(t, err1)
assert.Equal(t, "task1", result1)

result2, err2 := task2.Outcome()
assert.NoError(t, err2)
assert.Equal(t, "task2", result2)

result3, err3 := task3.Outcome()
assert.NoError(t, err3)
assert.Equal(t, "task3", result3)
}

func TestAfterWithCancellation(t *testing.T) {
task1 := NewTask(func(ctx context.Context) (string, error) {
time.Sleep(time.Millisecond * 10)
return "task1", nil
})

task2 := After(task1, func(ctx context.Context, result1 string) (any, error) {
time.Sleep(time.Millisecond * 50)
return "task2", nil
})

// Start the first task
task1.Run(context.Background())

// Cancel the continuation task while it's running
time.Sleep(time.Millisecond * 15) // Let first task complete
task2.Cancel()

// Verify first task completed
result1, err1 := task1.Outcome()
assert.NoError(t, err1)
assert.Equal(t, "task1", result1)

// Verify second task was cancelled
_, err2 := task2.Outcome()
assert.Error(t, err2)
assert.Equal(t, errCancelled, err2)
}

func TestAfterWithCompletedTask(t *testing.T) {
// Create a completed task
task1 := Completed("completed result")

// Chain after it
task2 := After(task1, func(ctx context.Context, result1 string) (any, error) {
return "continuation", nil
})

// Verify results
result1, err1 := task1.Outcome()
assert.NoError(t, err1)
assert.Equal(t, "completed result", result1)

result2, err2 := task2.Outcome()
if err2 != nil {
// If chaining is not supported, should return an error
assert.Error(t, err2)
assert.Contains(t, err2.Error(), "predecessor does not support chaining")
} else {
// If chaining is supported, should work
assert.Equal(t, "continuation", result2)
}
}