diff --git a/taskgroup_test.go b/taskgroup_test.go index dcf4bf3..0de4e3e 100644 --- a/taskgroup_test.go +++ b/taskgroup_test.go @@ -3,6 +3,8 @@ package taskgroup_test import ( "context" "errors" + "fmt" + "math" "math/rand/v2" "reflect" "sync" @@ -336,6 +338,75 @@ func TestCollector_Report(t *testing.T) { } } +func TestGatherer(t *testing.T) { + defer leaktest.Check(t)() + + g, run := taskgroup.New(nil).Limit(4) + checkWait := func(t *testing.T) { + t.Helper() + if err := g.Wait(); err != nil { + t.Errorf("Unexpected error from Wait: %v", err) + } + } + + t.Run("Call", func(t *testing.T) { + var sum int + r := taskgroup.Gather(run, func(v int) { + sum += v + }) + + for _, v := range rand.Perm(15) { + r.Call(func() (int, error) { + if v > 10 { + return -100, errors.New("don't add this") + } + return v, nil + }) + } + + g.Wait() + if want := (10 * 11) / 2; sum != want { + t.Errorf("Final result: got %d, want %d", sum, want) + } + }) + + t.Run("Run", func(t *testing.T) { + var sum int + r := taskgroup.Gather(run, func(v int) { + sum += v + }) + for _, v := range rand.Perm(15) { + r.Run(func() int { return v + 1 }) + } + + checkWait(t) + if want := (15 * 16) / 2; sum != want { + t.Errorf("Final result: got %d, want %d", sum, want) + } + }) + + t.Run("Report", func(t *testing.T) { + var sum uint32 + r := taskgroup.Gather(g.Go, func(v uint32) { + sum |= v + }) + + for _, i := range rand.Perm(32) { + r.Report(func(report func(v uint32)) error { + for _, v := range rand.Perm(i + 1) { + report(uint32(1 << v)) + } + return nil + }) + } + + checkWait(t) + if sum != math.MaxUint32 { + t.Errorf("Final result: got %d, want %d", sum, math.MaxUint32) + } + }) +} + type peakValue struct { μ sync.Mutex cur, max int @@ -355,3 +426,48 @@ func (p *peakValue) dec() { p.cur-- p.μ.Unlock() } + +func TestTree(t *testing.T) { + defer leaktest.Check(t)() + + vs := rand.Perm(1000) + + g, run := taskgroup.New(nil).Limit(5) + + type result [3]int + r := taskgroup.Gather(run, func(v result) { + t.Logf("+ %d at %d: %d", v[0], v[1], v[2]) + }) + + for i := range vs { + r.Run(func() result { + // Find the location of i in the permutation. + for j, v := range vs { + if v != i { + continue + } + + // Count the number of things less than i earlier in vs than i. + // Do this in the most inefficient possible way. + g, run := taskgroup.New(nil).Limit(5) + var countLess int + r := taskgroup.Gather(run, func(int) { + countLess++ + }) + + for k := range j { + r.Call(func() (int, error) { + if vs[k] < v { + return k, nil + } + return -1, errors.New("no") + }) + } + g.Wait() + return result{i, j, countLess} + } + panic(fmt.Sprintf("%d not found", i)) + }) + } + g.Wait() +}