Skip to content

Commit

Permalink
taskgroup: add a Gatherer type (#8)
Browse files Browse the repository at this point in the history
The Gatherer is intended to replace the Collector.  Instead of wrapping tasks
that then have to be given to a Group, the Gatherer manages a run function.

Also:
- Add basic tests
- Update README.md examples to use Gatherer
- Mark Collector as deprecated
- Isolate NoError so it can be deprecated later
  • Loading branch information
creachadair authored Oct 6, 2024
1 parent f9ec917 commit b85b384
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 48 deletions.
72 changes: 36 additions & 36 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Here is a [working example in the Go Playground](https://go.dev/play/p/wCZzMDXRU
- [Filtering Errors](#filtering-errors)
- [Controlling Concurrency](#controlling-concurrency)
- [Solo Tasks](#solo-tasks)
- [Collecting Results](#collecting-results)
- [Gathering Results](#gathering-results)

## Rationale

Expand Down Expand Up @@ -303,67 +303,67 @@ if err != nil {
doThingsWith(data)
```

## Collecting Results
## Gathering Results

One common use for a background task is accumulating the results from a batch
of concurrent workers. This could be handled by a solo task, as described
above, but it is a common enough pattern that the library provides a
`Collector` type to handle it specifically.
above, but it is a common enough pattern that the library provides a `Gatherer`
type to handle it specifically.

To use it, pass a function to `Collect` to receive the values:
To use it, pass a function to `Gather` to receive the values:

```go
var g taskgroup.Group

var sum int
c := taskgroup.Collect(func(v int) { sum += v })
c := taskgroup.Gather(g.Go, func(v int) { sum += v })
```

The `Call`, `Run`, and `Report` methods of `c` can now be used to wrap
functions that yield values, to deliver those values to `c`:
The `Call`, `Run`, and `Report` methods of `c` can now be used to start tasks
in `g` that yield values, and deliver those values to the accumulator:

- `c.Call` takes a `func() (T, error)`, returning a value and an error.
- `c.Run` takes a `func() T`, returning only a value.
If the task reports an error, that error is returned as usual. Otherwise,
its non-error value is gathered by the callback.

If the wrapped function reports an error, that error is returned from the task
as usual. Otherwise, its non-error value is given to the accumulator callback.
As in the above example, calls to the function are serialized so that it is
safe to access state without additional locking:
- `c.Run` takes a `func() T`, returning only a value, which is gathered by the
callback.

```go
var g taskgroup.Group
// ...
- `c.Report` takes a `func(func(T)) error`, which allows a task to report
_multiple_ values to the gatherer via a "report" callback. The task itself
returns only an `error`, but it may call its argument any number of times to
gather values.

Calls to the callback are serialized so that it is safe to access state without
additional locking:

// Report an error, no value is sent to the collector.
g.Go(c.Call(func() (int, error) {
```go
// Report an error, no value is gathered.
c.Call(func() (int, error) {
return -1, errors.New("bad")
}))
})

// No error, send the value 25 to the collector.
g.Go(c.Call(func() (int, error) {
// No error, send gather the value 25.
c.Call(func() (int, error) {
return 25, nil
}))

// Send a random integer to the collector.
g.Go(c.Run(func() int { return rand.Intn(1000) })
```
})

The `Report` method allows a task to report _multiple_ values to the collector
via a callback. Here, the function returns only an `error`, but it receives a
callback it may invoke any number of times to send values:
// Gather a random integer.
c.Run(func() int { return rand.Intn(1000) })

```go
// Send the values 10, 20, and 30 to the collector.
// Gather the values 10, 20, and 30.
//
// Note that even if the function reports an error, any values it sent to
// the collector before returning are still delivered.
g.Go(c.Report(func(report func(int)) error {
// Note that even if the function reports an error, any values it sent
// before returning are still gathered.
c.Report(func(report func(int)) error {
report(10)
report(20)
report(30)
return nil
}))
})
```

Once all the tasks derived from the collector are done, it is safe to access
Once all the tasks passed to the gatherer are complete, it is safe to access
the values accumulated by the callback:

```go
Expand Down
63 changes: 62 additions & 1 deletion collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import "sync"

// A Collector collects values reported by task functions and delivers them to
// an accumulator function.
//
// Deprecated: Use a [Gatherer] instead.
type Collector[T any] struct {
μ sync.Mutex
handle func(T)
Expand All @@ -22,6 +24,8 @@ func (c *Collector[T]) report(v T) {
//
// The tasks created from a collector do not return until all the values
// reported by the underlying function have been processed by the accumulator.
//
// Deprecated: Use [Gather] instead.
func Collect[T any](value func(T)) *Collector[T] { return &Collector[T]{handle: value} }

// Call returns a Task wrapping a call to f. If f reports an error, that error
Expand All @@ -48,5 +52,62 @@ func (c *Collector[T]) Report(f func(report func(T)) error) Task {
// Run returns a Task wrapping a call to f. The resulting task reports a nil
// error for all calls.
func (c *Collector[T]) Run(f func() T) Task {
return NoError(func() { c.report(f()) })
return noError(func() { c.report(f()) })
}

// A Gatherer manages a group of [Task] functions that report values, and
// gathers the values they return.
type Gatherer[T any] struct {
run func(Task) // start the task in a goroutine

μ sync.Mutex
gather func(T) // handle values reported by tasks
}

func (g *Gatherer[T]) report(v T) {
g.μ.Lock()
defer g.μ.Unlock()
g.gather(v)
}

// Gather creates a new empty gatherer that uses run to execute tasks returning
// values of type T.
//
// If gather != nil, values reported by successful tasks are passed to the
// function, otherwise such values are discarded. Calls to gather are
// synchronized to a single goroutine.
//
// If run == nil, Gather will panic.
func Gather[T any](run func(Task), gather func(T)) *Gatherer[T] {
if run == nil {
panic("run function is nil")
}
if gather == nil {
gather = func(T) {}
}
return &Gatherer[T]{run: run, gather: gather}
}

// Call runs f in g. If f reports an error, the error is propagated to the
// runner; otherwise the non-error value reported by f is gathered.
func (g *Gatherer[T]) Call(f func() (T, error)) {
g.run(func() error {
v, err := f()
if err == nil {
g.report(v)
}
return err
})
}

// Run runs f in g, and gathers the value it reports.
func (g *Gatherer[T]) Run(f func() T) {
g.run(func() error { g.report(f()); return nil })
}

// Report runs f in g. Any values passed to report are gathered. If f reports
// an error, that error is propagated to the runner. Any values sent before f
// returns are still gathered, even if f reports an error.
func (g *Gatherer[T]) Report(f func(report func(T)) error) {
g.run(func() error { return f(g.report) })
}
8 changes: 2 additions & 6 deletions single.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,8 @@ func Go[T any](task func() T) *Single[T] {
}

// Run runs task in a new goroutine. The caller must call Wait to wait for the
// task to return and collect its error. This is shorthand for:
//
// taskgroup.Go(taskgroup.NoError(task))
//
// The error reported by Wait is always nil.
func Run(task func()) *Single[error] { return Go(NoError(task)) }
// task to return. The error reported by Wait is always nil.
func Run(task func()) *Single[error] { return Go(noError(task)) }

// Call starts task in a new goroutine. The caller must call Wait to wait for
// the task to return and collect its result.
Expand Down
10 changes: 5 additions & 5 deletions taskgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,9 @@ func (g *Group) Go(task Task) {
}()
}

// Run runs task in a new goroutine in g, and returns g to permit chaining.
// This is shorthand for:
//
// g.Go(taskgroup.NoError(task))
func (g *Group) Run(task func()) { g.Go(NoError(task)) }
// Run runs task in a new goroutine in g.
// The resulting task reports a nil error.
func (g *Group) Run(task func()) { g.Go(noError(task)) }

func (g *Group) handleError(err error) {
g.μ.Lock()
Expand Down Expand Up @@ -195,6 +193,8 @@ func Listen(f func(error)) any { return f }
// NoError adapts f to a Task that executes f and reports a nil error.
func NoError(f func()) Task { return func() error { f(); return nil } }

func noError(f func()) Task { return func() error { f(); return nil } }

// Limit returns g and a "start" function that starts each task passed to it in
// g, allowing no more than n tasks to be active concurrently. If n ≤ 0, no
// limit is enforced.
Expand Down
70 changes: 70 additions & 0 deletions taskgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package taskgroup_test
import (
"context"
"errors"
"math"
"math/rand/v2"
"reflect"
"sync"
Expand Down Expand Up @@ -336,6 +337,75 @@ func TestCollector_Report(t *testing.T) {
}
}

func TestGatherer(t *testing.T) {
defer leaktest.Check(t)()

g, run := taskgroup.New(nil).Limit(4)
checkWait := func(t *testing.T) {
t.Helper()
if err := g.Wait(); err != nil {
t.Errorf("Unexpected error from Wait: %v", err)
}
}

t.Run("Call", func(t *testing.T) {
var sum int
r := taskgroup.Gather(run, func(v int) {
sum += v
})

for _, v := range rand.Perm(15) {
r.Call(func() (int, error) {
if v > 10 {
return -100, errors.New("don't add this")
}
return v, nil
})
}

g.Wait()
if want := (10 * 11) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
})

t.Run("Run", func(t *testing.T) {
var sum int
r := taskgroup.Gather(run, func(v int) {
sum += v
})
for _, v := range rand.Perm(15) {
r.Run(func() int { return v + 1 })
}

checkWait(t)
if want := (15 * 16) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
})

t.Run("Report", func(t *testing.T) {
var sum uint32
r := taskgroup.Gather(g.Go, func(v uint32) {
sum |= v
})

for _, i := range rand.Perm(32) {
r.Report(func(report func(v uint32)) error {
for _, v := range rand.Perm(i + 1) {
report(uint32(1 << v))
}
return nil
})
}

checkWait(t)
if sum != math.MaxUint32 {
t.Errorf("Final result: got %d, want %d", sum, math.MaxUint32)
}
})
}

type peakValue struct {
μ sync.Mutex
cur, max int
Expand Down

0 comments on commit b85b384

Please sign in to comment.