Skip to content

Commit

Permalink
[FIXED] Race condition when resetting ordered consumer
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Jan 16, 2024
1 parent a8a8d18 commit 67bfd10
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
16 changes: 9 additions & 7 deletions jetstream/ordered.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
for {
select {
case <-c.doReset:
if err := c.reset(); err != nil {
sub, ok := c.currentConsumer.getSubscription("")
if !ok {
return
}
c.errHandler(c.serial)(sub, err)
}
if c.withStopAfter {
select {
case c.stopAfter = <-c.stopAfterMsgsLeft:
Expand All @@ -149,13 +156,6 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
return
}
}
if err := c.reset(); err != nil {
sub, ok := c.currentConsumer.getSubscription("")
if !ok {
return
}
c.errHandler(c.serial)(sub, err)
}
if c.stopAfter > 0 {
opts = opts[:len(opts)-2]
} else {
Expand Down Expand Up @@ -190,6 +190,8 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt

func (c *orderedConsumer) errHandler(serial int) func(cc ConsumeContext, err error) {
return func(cc ConsumeContext, err error) {
c.Lock()
defer c.Unlock()
if c.userErrHandler != nil && !errors.Is(err, errOrderedSequenceMismatch) {
c.userErrHandler(cc, err)
}
Expand Down
9 changes: 3 additions & 6 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,12 +649,12 @@ func (s *pullSubscription) Next() (Msg, error) {

func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error {
if !errors.Is(msgErr, nats.ErrTimeout) && !errors.Is(msgErr, ErrMaxBytesExceeded) {
if s.consumeOpts.ErrHandler != nil {
s.consumeOpts.ErrHandler(s, msgErr)
}
if errors.Is(msgErr, ErrConsumerDeleted) || errors.Is(msgErr, ErrBadRequest) {
return msgErr
}
if s.consumeOpts.ErrHandler != nil {
s.consumeOpts.ErrHandler(s, msgErr)
}
if errors.Is(msgErr, ErrConsumerLeadershipChanged) {
s.pending.msgCount = 0
s.pending.byteCount = 0
Expand All @@ -663,9 +663,6 @@ func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error {
}
msgsLeft, bytesLeft, err := parsePending(msg)
if err != nil {
if s.consumeOpts.ErrHandler != nil {
s.consumeOpts.ErrHandler(s, err)
}
return err
}
s.pending.msgCount -= msgsLeft
Expand Down
19 changes: 17 additions & 2 deletions jetstream/test/ordered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,16 +365,25 @@ func TestOrderedConsumerConsume(t *testing.T) {
}

for i := 0; i < 100; i++ {
if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil {
if _, err := js.Publish(ctx, "FOO.A", []byte(fmt.Sprintf("msg%d", i))); err != nil {
t.Fatalf("Unexpected error during publish: %s", err)
}
}
msgs := make([]jetstream.Msg, 0)
wg := &sync.WaitGroup{}
wg.Add(100)
mu := &sync.Mutex{}
var i int
_, err = c.Consume(func(msg jetstream.Msg) {
msgs = append(msgs, msg)
msg.Ack()
mu.Lock()
i++
if i > 150 {
mu.Unlock()
return
}
mu.Unlock()
wg.Done()
}, jetstream.StopAfter(150), jetstream.PullMaxMessages(40))
if err != nil {
Expand All @@ -386,11 +395,17 @@ func TestOrderedConsumerConsume(t *testing.T) {
}
wg.Add(50)
for i := 0; i < 100; i++ {
if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil {
if _, err := js.Publish(ctx, "FOO.A", []byte(fmt.Sprintf("msg%d", i+100))); err != nil {
t.Fatalf("Unexpected error during publish: %s", err)
}
}
wg.Wait()
time.Sleep(10 * time.Millisecond)
mu.Lock()
if i > 150 {
t.Fatalf("Unexpected number of messages; want 150; got %d", i)
}
mu.Unlock()

time.Sleep(10 * time.Millisecond)
ci, err := c.Info(ctx)
Expand Down

0 comments on commit 67bfd10

Please sign in to comment.