From 67bfd1053f0c40fb10ce100b65794042545e68db Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Tue, 16 Jan 2024 12:45:55 +0100 Subject: [PATCH] [FIXED] Race condition when resetting ordered consumer Signed-off-by: Piotr Piotrowski --- jetstream/ordered.go | 16 +++++++++------- jetstream/pull.go | 9 +++------ jetstream/test/ordered_test.go | 19 +++++++++++++++++-- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/jetstream/ordered.go b/jetstream/ordered.go index e4e8bde1f..32f7b2495 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -139,6 +139,13 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt for { select { case <-c.doReset: + if err := c.reset(); err != nil { + sub, ok := c.currentConsumer.getSubscription("") + if !ok { + return + } + c.errHandler(c.serial)(sub, err) + } if c.withStopAfter { select { case c.stopAfter = <-c.stopAfterMsgsLeft: @@ -149,13 +156,6 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt return } } - if err := c.reset(); err != nil { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return - } - c.errHandler(c.serial)(sub, err) - } if c.stopAfter > 0 { opts = opts[:len(opts)-2] } else { @@ -190,6 +190,8 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt func (c *orderedConsumer) errHandler(serial int) func(cc ConsumeContext, err error) { return func(cc ConsumeContext, err error) { + c.Lock() + defer c.Unlock() if c.userErrHandler != nil && !errors.Is(err, errOrderedSequenceMismatch) { c.userErrHandler(cc, err) } diff --git a/jetstream/pull.go b/jetstream/pull.go index 44e6b3654..856e6e7ed 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -649,12 +649,12 @@ func (s *pullSubscription) Next() (Msg, error) { func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error { if !errors.Is(msgErr, nats.ErrTimeout) && !errors.Is(msgErr, ErrMaxBytesExceeded) { - if s.consumeOpts.ErrHandler != nil { - s.consumeOpts.ErrHandler(s, msgErr) - } if errors.Is(msgErr, ErrConsumerDeleted) || errors.Is(msgErr, ErrBadRequest) { return msgErr } + if s.consumeOpts.ErrHandler != nil { + s.consumeOpts.ErrHandler(s, msgErr) + } if errors.Is(msgErr, ErrConsumerLeadershipChanged) { s.pending.msgCount = 0 s.pending.byteCount = 0 @@ -663,9 +663,6 @@ func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error { } msgsLeft, bytesLeft, err := parsePending(msg) if err != nil { - if s.consumeOpts.ErrHandler != nil { - s.consumeOpts.ErrHandler(s, err) - } return err } s.pending.msgCount -= msgsLeft diff --git a/jetstream/test/ordered_test.go b/jetstream/test/ordered_test.go index c8b529f16..312cc8ebf 100644 --- a/jetstream/test/ordered_test.go +++ b/jetstream/test/ordered_test.go @@ -365,16 +365,25 @@ func TestOrderedConsumerConsume(t *testing.T) { } for i := 0; i < 100; i++ { - if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil { + if _, err := js.Publish(ctx, "FOO.A", []byte(fmt.Sprintf("msg%d", i))); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } msgs := make([]jetstream.Msg, 0) wg := &sync.WaitGroup{} wg.Add(100) + mu := &sync.Mutex{} + var i int _, err = c.Consume(func(msg jetstream.Msg) { msgs = append(msgs, msg) msg.Ack() + mu.Lock() + i++ + if i > 150 { + mu.Unlock() + return + } + mu.Unlock() wg.Done() }, jetstream.StopAfter(150), jetstream.PullMaxMessages(40)) if err != nil { @@ -386,11 +395,17 @@ func TestOrderedConsumerConsume(t *testing.T) { } wg.Add(50) for i := 0; i < 100; i++ { - if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil { + if _, err := js.Publish(ctx, "FOO.A", []byte(fmt.Sprintf("msg%d", i+100))); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } wg.Wait() + time.Sleep(10 * time.Millisecond) + mu.Lock() + if i > 150 { + t.Fatalf("Unexpected number of messages; want 150; got %d", i) + } + mu.Unlock() time.Sleep(10 * time.Millisecond) ci, err := c.Info(ctx)