From e5c1cd06eb3e17c8325a0780f5aa904381efc686 Mon Sep 17 00:00:00 2001 From: Yusheng Li Date: Sun, 13 Oct 2024 15:06:04 +0800 Subject: [PATCH] feat(worker): goroutine pool (#43) --- app/app.go | 5 ++- config.yml | 3 ++ config/worker.go | 6 +++ pkg/pool/pool.go | 90 +++++++++++++++++++++++++++++++++++++++++++ pkg/pool/pool_test.go | 80 ++++++++++++++++++++++++++++++++++++++ pkg/pool/task.go | 26 +++++++++++++ worker/worker.go | 35 ++++++++++------- 7 files changed, 230 insertions(+), 15 deletions(-) create mode 100644 pkg/pool/pool.go create mode 100644 pkg/pool/pool_test.go create mode 100644 pkg/pool/task.go diff --git a/app/app.go b/app/app.go index ed2b0ae..f654222 100644 --- a/app/app.go +++ b/app/app.go @@ -87,7 +87,10 @@ func (app *Application) initialize() error { // worker if cfg.WorkerConfig.Enabled { - opts := worker.WorkerOptions{} + opts := worker.WorkerOptions{ + PoolSize: int(cfg.WorkerConfig.Pool.Size), + PoolConcurrency: int(cfg.WorkerConfig.Pool.Concurrency), + } deliverer := deliverer.NewHTTPDeliverer(&cfg.WorkerConfig.Deliverer) app.worker = worker.NewWorker(opts, db, deliverer, queue) } diff --git a/config.yml b/config.yml index 7182e03..681100f 100644 --- a/config.yml +++ b/config.yml @@ -37,6 +37,9 @@ worker: enabled: false deliverer: timeout: 60000 + pool: + size: 10000 # pool size, default to 10000. + concurrency: 0 # pool concurrency, default to 100 * CPUs #------------------------------------------------------------------------------ # PROXY diff --git a/config/worker.go b/config/worker.go index 96f94c7..61c335a 100644 --- a/config/worker.go +++ b/config/worker.go @@ -4,9 +4,15 @@ type WorkerDeliverer struct { Timeout int64 `yaml:"timeout" default:"60000"` } +type Pool struct { + Size uint32 `yaml:"size" default:"10000"` + Concurrency uint32 `yaml:"concurrency"` +} + type WorkerConfig struct { Enabled bool `yaml:"enabled" default:"false"` Deliverer WorkerDeliverer `yaml:"deliverer"` + Pool Pool `yaml:"pool"` } func (cfg *WorkerConfig) Validate() error { diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go new file mode 100644 index 0000000..b565e65 --- /dev/null +++ b/pkg/pool/pool.go @@ -0,0 +1,90 @@ +package pool + +import ( + "context" + "errors" + "sync" + "time" +) + +var ( + ErrPoolTernimated = errors.New("pool is ternimated") + ErrTimeout = errors.New("timeout") +) + +type Pool struct { + ctx context.Context + cancel context.CancelFunc + + workers int + + tasks chan Task + wait sync.WaitGroup +} + +func NewPool(size int, workers int) *Pool { + ctx, cancel := context.WithCancel(context.Background()) + pool := &Pool{ + ctx: ctx, + cancel: cancel, + workers: workers, + tasks: make(chan Task, size), + } + + pool.wait.Add(workers) + + for i := 0; i < workers; i++ { + go pool.consume() + } + + return pool +} + +func (p *Pool) SubmitFn(timeout time.Duration, fn func()) error { + if fn == nil { + return errors.New("fn is nil") + } + + taks := &task{ + fn: fn, + } + return p.Submit(timeout, taks) +} + +func (p *Pool) Submit(timeout time.Duration, task Task) error { + if task == nil { + return errors.New("task is nil") + } + + if p.ctx.Err() != nil { + return ErrPoolTernimated + } + + select { + case p.tasks <- task: + return nil + case <-time.After(timeout): + return ErrTimeout + } +} + +func (p *Pool) consume() { + defer p.wait.Done() + for { + select { + case <-p.ctx.Done(): + return + case t := <-p.tasks: + t.Execute() + } + } +} + +func (p *Pool) Shutdown() { + if err := p.ctx.Err(); err != nil { + return + } + + p.cancel() + p.wait.Wait() +} diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go new file mode 100644 index 0000000..7d1648e --- /dev/null +++ b/pkg/pool/pool_test.go @@ -0,0 +1,80 @@ +package pool + +import ( + "github.com/stretchr/testify/assert" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestError(t *testing.T) { + pool := NewPool(0, 1) + + err := pool.SubmitFn(time.Second, nil) + assert.Equal(t, "fn is nil", err.Error()) + + err = pool.Submit(time.Second, nil) + assert.Equal(t, "task is nil", err.Error()) + + // panic should be recovered + err = pool.SubmitFn(time.Second, func() { + panic("foo") + }) + assert.NoError(t, err) + + pool.Shutdown() + pool.Shutdown() // no panic +} + +func TestSubmit(t *testing.T) { + pool := NewPool(5, 1) + wait := sync.WaitGroup{} + for i := 0; i < 5; i++ { + wait.Add(1) + err := pool.SubmitFn(time.Second, func() { + wait.Done() + }) + assert.NoError(t, err) + } + wait.Wait() +} + +func TestSubmitWithTimeout(t *testing.T) { + pool := NewPool(1, 1) + err := pool.SubmitFn(time.Second, func() { + time.Sleep(time.Second * 5) + }) + assert.NoError(t, err) + err = pool.SubmitFn(time.Second, func() { + time.Sleep(time.Second * 5) + }) + assert.NoError(t, err) + err = pool.SubmitFn(time.Second, func() {}) + assert.Equal(t, ErrTimeout, err) +} + +func TestShutdown(t *testing.T) { + pool := NewPool(1, 1) + pool.Shutdown() + err := pool.SubmitFn(time.Second, func() {}) + assert.Equal(t, ErrPoolTernimated, err) +} + +func TestGracefulShutdown(t *testing.T) { + var counter int64 + atomic.StoreInt64(&counter, 0) + + pool := NewPool(100, 100) + + for i := 0; i < 100; i++ { + err := pool.SubmitFn(time.Second, func() { + time.Sleep(time.Second) + atomic.AddInt64(&counter, 1) + }) + assert.NoError(t, err) + } + + pool.Shutdown() + assert.EqualValues(t, 100, counter) // all submitted and scheduled tasks should be executed successfully +} diff --git a/pkg/pool/task.go b/pkg/pool/task.go new file mode 100644 index 0000000..92d2893 --- /dev/null +++ b/pkg/pool/task.go @@ -0,0 +1,26 @@ +package pool + +import ( + "fmt" + "runtime" +) + +type Task interface { + Execute() +} + +type task struct { + fn func() +} + +func (t *task) Execute() { + defer func() { + if e := recover(); e != nil { + buf := make([]byte, 2048) + n := runtime.Stack(buf, false) + buf = buf[:n] + fmt.Printf("panic recovered: %v\n %s\n", e, buf) + } + }() + t.fn() +} diff --git a/worker/worker.go b/worker/worker.go index ca4c1fb..2b22d3d 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -6,15 +6,16 @@ import ( "github.com/webhookx-io/webhookx/constants" "github.com/webhookx-io/webhookx/pkg/plugin" plugintypes "github.com/webhookx-io/webhookx/pkg/plugin/types" - "github.com/webhookx-io/webhookx/pkg/safe" + "github.com/webhookx-io/webhookx/pkg/pool" + "github.com/webhookx-io/webhookx/pkg/schedule" "github.com/webhookx-io/webhookx/pkg/taskqueue" + "runtime" "time" "github.com/webhookx-io/webhookx/db" "github.com/webhookx-io/webhookx/db/dao" "github.com/webhookx-io/webhookx/db/entities" "github.com/webhookx-io/webhookx/model" - "github.com/webhookx-io/webhookx/pkg/schedule" "github.com/webhookx-io/webhookx/pkg/types" "github.com/webhookx-io/webhookx/utils" "github.com/webhookx-io/webhookx/worker/deliverer" @@ -27,30 +28,33 @@ type Worker struct { opts WorkerOptions - stop chan struct{} - log *zap.SugaredLogger + log *zap.SugaredLogger queue taskqueue.TaskQueue deliverer deliverer.Deliverer DB *db.DB + pool *pool.Pool } type WorkerOptions struct { RequeueJobBatch int RequeueJobInterval time.Duration + PoolSize int + PoolConcurrency int } func NewWorker(opts WorkerOptions, db *db.DB, deliverer deliverer.Deliverer, queue taskqueue.TaskQueue) *Worker { opts.RequeueJobBatch = utils.DefaultIfZero(opts.RequeueJobBatch, constants.RequeueBatch) opts.RequeueJobInterval = utils.DefaultIfZero(opts.RequeueJobInterval, constants.RequeueInterval) - + opts.PoolSize = utils.DefaultIfZero(opts.PoolSize, 10000) + opts.PoolConcurrency = utils.DefaultIfZero(opts.PoolConcurrency, runtime.NumCPU()*100) worker := &Worker{ opts: opts, - stop: make(chan struct{}), queue: queue, log: zap.S(), deliverer: deliverer, DB: db, + pool: pool.NewPool(opts.PoolSize, opts.PoolConcurrency), } return worker @@ -62,8 +66,7 @@ func (w *Worker) run() { for { select { - case <-w.stop: - w.log.Info("[worker] receive stop signal") + case <-w.ctx.Done(): return case <-ticker.C: for { @@ -76,7 +79,7 @@ func (w *Worker) run() { break } w.log.Debugf("[worker] receive task: %v", task) - safe.Go(func() { + err = w.pool.SubmitFn(time.Second*10, func() { task.Data = &model.MessageData{} err = task.UnmarshalData(task.Data) if err != nil { @@ -94,6 +97,12 @@ func (w *Worker) run() { _ = w.queue.Delete(context.TODO(), task) }) + if err != nil { + if errors.Is(err, pool.ErrPoolTernimated) { + return // worker is shutting down + } + w.log.Warnf("[worker] failed to submit a task: %v", err) // consider tuning pool configuration + } } } } @@ -105,18 +114,16 @@ func (w *Worker) Start() error { w.ctx, w.cancel = context.WithCancel(context.Background()) schedule.Schedule(w.ctx, w.processRequeue, w.opts.RequeueJobInterval) - + w.log.Infof("[worker] created pool(size=%d, concurrency=%d)", w.opts.PoolSize, w.opts.PoolConcurrency) w.log.Info("[worker] started") return nil } // Stop stops worker func (w *Worker) Stop() error { - // TODO: wait for all go routines finished - w.cancel() - - w.stop <- struct{}{} + w.log.Info("[worker] goroutine pool is shutting down") + w.pool.Shutdown() w.log.Info("[worker] stopped") return nil