diff --git a/resourcemanager/pooltask/task.go b/resourcemanager/pooltask/task.go index 7dc0898129386..ef9b046c8ccba 100644 --- a/resourcemanager/pooltask/task.go +++ b/resourcemanager/pooltask/task.go @@ -121,6 +121,11 @@ func (t *TaskController[T, U, C, CT, TF]) Wait() { close(t.resultCh) } +// TaskID is to get the task id. +func (t *TaskController[T, U, C, CT, TF]) TaskID() uint64 { + return t.taskID +} + // Task is a task that can be executed. type Task[T any] struct { Task T diff --git a/util/gpool/spmc/spmcpool.go b/util/gpool/spmc/spmcpool.go index 19f2c2cc63da1..a81f423ab8564 100644 --- a/util/gpool/spmc/spmcpool.go +++ b/util/gpool/spmc/spmcpool.go @@ -196,6 +196,7 @@ func (p *Pool[T, U, C, CT, TF]) release() { // There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent // those callers blocking infinitely. p.cond.Broadcast() + close(p.taskCh) } func isClose(exitCh chan struct{}) bool { diff --git a/util/gpool/spmc/spmcpool_test.go b/util/gpool/spmc/spmcpool_test.go index 5aa54313274ed..984f501789c47 100644 --- a/util/gpool/spmc/spmcpool_test.go +++ b/util/gpool/spmc/spmcpool_test.go @@ -191,18 +191,22 @@ func TestPoolWithoutEnoughCapacityParallel(t *testing.T) { p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { return struct{}{} }) - var twg util.WaitGroupWrapper + var twg sync.WaitGroup for i := 0; i < 10; i++ { - twg.Run(func() { + twg.Add(1) + go func() { + defer twg.Done() sema := make(chan struct{}, 10) - var wg util.WaitGroupWrapper + var wg sync.WaitGroup exitCh := make(chan struct{}) - wg.Run(func() { + wg.Add(1) + go func() { + wg.Done() for j := 0; j < RunTimes; j++ { sema <- struct{}{} } close(exitCh) - }) + }() producerFunc := func() (struct{}, error) { for { select { @@ -218,14 +222,15 @@ func TestPoolWithoutEnoughCapacityParallel(t *testing.T) { } } resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(concurrency)) - - wg.Run(func() { + wg.Add(1) + go func() { + defer wg.Done() for range resultCh { } - }) + }() ctl.Wait() wg.Wait() - }) + }() } twg.Wait() } @@ -240,14 +245,16 @@ func TestBenchPool(t *testing.T) { for i := 0; i < 1000; i++ { sema := make(chan struct{}, 10) - var wg util.WaitGroupWrapper + var wg sync.WaitGroup exitCh := make(chan struct{}) - wg.Run(func() { + wg.Add(1) + go func() { + defer wg.Done() for j := 0; j < RunTimes; j++ { sema <- struct{}{} } close(exitCh) - }) + }() producerFunc := func() (struct{}, error) { for { select {