Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions pkg/redisstream/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package redisstream
import (
"context"
"math/rand"
"sort"
"strconv"
"sync"
"testing"
Expand Down Expand Up @@ -343,3 +344,86 @@ func TestClaimIdle(t *testing.T) {

assert.GreaterOrEqual(t, nMsgsWithRetries, 3)
}

func TestSubscriber_ClaimAllMessages(t *testing.T) {
rdb := redisClientOrFail(t)

logger := watermill.NewStdLogger(true, true)

topic := watermill.NewShortUUID()
consumerGroup := watermill.NewShortUUID()

// This one should claim all messages
subGood, err := NewSubscriber(SubscriberConfig{
Client: rdb,
ConsumerGroup: consumerGroup,
Consumer: "good",
MaxIdleTime: 500 * time.Millisecond,
ClaimInterval: 500 * time.Millisecond,
CheckConsumersInterval: 1 * time.Second,
ConsumerTimeout: 2 * time.Second,
}, logger)
require.NoError(t, err)

// This one never acks
subBad, err := NewSubscriber(SubscriberConfig{
Client: rdb,
ConsumerGroup: consumerGroup,
Consumer: "bad",
}, logger)
require.NoError(t, err)

pub, err := NewPublisher(PublisherConfig{
Client: rdb,
}, logger)
require.NoError(t, err)

for i := 0; i < 10; i++ {
err = pub.Publish(topic, message.NewMessage(watermill.NewUUID(), []byte(strconv.Itoa(i))))
assert.NoError(t, err)
}

badCtx, badCancel := context.WithCancel(context.Background())
defer badCancel()

msgs, err := subBad.Subscribe(badCtx, topic)
require.NoError(t, err)

// Pull a message, don't ack it!
<-msgs

// Cancel the bad subscriber
badCancel()

goodCtx, goodCancel := context.WithCancel(context.Background())
defer goodCancel()

msgs, err = subGood.Subscribe(goodCtx, topic)
require.NoError(t, err)

var processedMessages []string

// Try to receive all messages
for i := 0; i < 10; i++ {
select {
case msg, ok := <-msgs:
assert.True(t, ok)
processedMessages = append(processedMessages, string(msg.Payload))
msg.Ack()
case <-time.After(5 * time.Second):
t.Fatal("Timeout waiting to receive all messages")
}
}

sort.Strings(processedMessages)
var expected []string
for i := 0; i < 10; i++ {
expected = append(expected, strconv.Itoa(i))
}
assert.Equal(t, expected, processedMessages)

assert.Eventually(t, func() bool {
xic, _ := rdb.XInfoConsumers(context.Background(), topic, consumerGroup).Result()
return len(xic) == 1 && xic[0].Name == "good"
}, 5*time.Second, 100*time.Millisecond, "Idle consumer should be deleted")
}
91 changes: 72 additions & 19 deletions pkg/redisstream/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,21 @@ const (

DefaultBlockTime = time.Millisecond * 100

// How often to check for dead workers to claim pending messages from
// How often to claim pending messages
DefaultClaimInterval = time.Second * 5

DefaultClaimBatchSize = int64(100)

// Default max idle time for pending message.
// After timeout, the message will be claimed and its idle consumer will be removed from consumer group
// Default max idle time for pending message
// After timeout, the message will be claimed
DefaultMaxIdleTime = time.Second * 60

// How often to check for dead consumers
DefaultCheckConsumersInterval = time.Second * 300

// Default consumer timeout
// After being idle longer than timeout and having no pending messages, it will be removed from the consumer group
DefaultConsumerTimeout = time.Second * 600
)

type Subscriber struct {
Expand Down Expand Up @@ -86,9 +93,15 @@ type SubscriberConfig struct {
// How many pending messages are claimed at most each claim interval
ClaimBatchSize int64

// How long should we treat a consumer as offline
// How long should we treat a pending message as claimable
MaxIdleTime time.Duration

// Check consumer status interval
CheckConsumersInterval time.Duration

// After which time an idle consumer with no pending messages will be removed from the consumer group
ConsumerTimeout time.Duration

// Start consumption from the specified message ID
// When using "0", the consumer group will consume from the very first message
// When using "$", the consumer group will consume from the latest message
Expand Down Expand Up @@ -129,6 +142,12 @@ func (sc *SubscriberConfig) setDefaults() {
if sc.MaxIdleTime == 0 {
sc.MaxIdleTime = DefaultMaxIdleTime
}
if sc.CheckConsumersInterval == 0 {
sc.CheckConsumersInterval = DefaultCheckConsumersInterval
}
if sc.ConsumerTimeout == 0 {
sc.ConsumerTimeout = DefaultConsumerTimeout
}
// Consume from scratch by default
if sc.OldestId == "" {
sc.OldestId = "0"
Expand Down Expand Up @@ -244,9 +263,9 @@ func (s *Subscriber) consumeStreams(ctx context.Context, stream string, output c

func (s *Subscriber) read(ctx context.Context, stream string, readChannel chan<- *redis.XStream, logFields watermill.LogFields) {
wg := &sync.WaitGroup{}
claimCtx, claimCancel := context.WithCancel(ctx)
subCtx, subCancel := context.WithCancel(ctx)
defer func() {
claimCancel()
subCancel()
wg.Wait()
close(readChannel)
}()
Expand All @@ -265,11 +284,15 @@ func (s *Subscriber) read(ctx context.Context, stream string, readChannel chan<-
if s.config.ConsumerGroup != "" {
// 1. get pending message from idle consumer
wg.Add(1)
s.claim(claimCtx, stream, readChannel, false, wg, logFields)
s.claim(subCtx, stream, readChannel, false, wg, logFields)

// 2. background
wg.Add(1)
go s.claim(claimCtx, stream, readChannel, true, wg, logFields)
go s.claim(subCtx, stream, readChannel, true, wg, logFields)

// check consumer status and remove idling consumers if possible
wg.Add(1)
go s.checkConsumers(subCtx, stream, wg, logFields)
}

for {
Expand Down Expand Up @@ -327,7 +350,6 @@ func (s *Subscriber) read(ctx context.Context, stream string, readChannel chan<-
}

func (s *Subscriber) claim(ctx context.Context, stream string, readChannel chan<- *redis.XStream, keep bool, wg *sync.WaitGroup, logFields watermill.LogFields) {
defer wg.Done()
var (
xps []redis.XPendingExt
err error
Expand All @@ -339,6 +361,7 @@ func (s *Subscriber) claim(ctx context.Context, stream string, readChannel chan<
defer func() {
tick.Stop()
close(initCh)
wg.Done()
}()
if !keep { // if not keep, run immediately
initCh <- 1
Expand Down Expand Up @@ -396,16 +419,6 @@ OUTER_LOOP:
)
continue OUTER_LOOP
}

// delete idle consumer
if err = s.client.XGroupDelConsumer(ctx, stream, s.config.ConsumerGroup, xp.Consumer).Err(); err != nil {
s.logger.Error(
"xgroupdelconsumer fail",
err,
logFields.Add(watermill.LogFields{"xp": xp}),
)
continue OUTER_LOOP
}
if len(xm) > 0 {
select {
case <-s.closing:
Expand All @@ -426,6 +439,46 @@ OUTER_LOOP:
}
}

func (s *Subscriber) checkConsumers(ctx context.Context, stream string, wg *sync.WaitGroup, logFields watermill.LogFields) {
tick := time.NewTicker(s.config.CheckConsumersInterval)
defer func() {
tick.Stop()
wg.Done()
}()

for {
select {
case <-s.closing:
return
case <-ctx.Done():
return
case <-tick.C:
}
xics, err := s.client.XInfoConsumers(ctx, stream, s.config.ConsumerGroup).Result()
if err != nil {
s.logger.Error(
"xinfoconsumers failed",
err,
logFields,
)
}
for _, xic := range xics {
if xic.Idle < s.config.ConsumerTimeout {
continue
}
if xic.Pending == 0 {
if err = s.client.XGroupDelConsumer(ctx, stream, s.config.ConsumerGroup, xic.Name).Err(); err != nil {
s.logger.Error(
"xgroupdelconsumer failed",
err,
logFields,
)
}
}
}
}
}

func (s *Subscriber) createMessageHandler(output chan *message.Message) messageHandler {
return messageHandler{
outputChannel: output,
Expand Down