Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 164 additions & 34 deletions pkg/redisstream/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package redisstream

import (
"context"
"math/rand"
"strconv"
"sync"
"testing"
"time"

Expand All @@ -12,6 +14,7 @@ import (

"github.com/pkg/errors"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -29,16 +32,18 @@ func redisClient() (redis.UniversalClient, error) {
return client, nil
}

func redisClientOrFail(t *testing.T) redis.UniversalClient {
client, err := redisClient()
require.NoError(t, err)
return client
}

func newPubSub(t *testing.T, subConfig *SubscriberConfig) (message.Publisher, message.Subscriber) {
logger := watermill.NewStdLogger(true, false)

pubClient, err := redisClient()
require.NoError(t, err)

publisher, err := NewPublisher(
PublisherConfig{
Client: pubClient,
Marshaller: &DefaultMarshallerUnmarshaller{},
Client: redisClientOrFail(t),
},
watermill.NewStdLogger(false, false),
)
Expand All @@ -55,12 +60,8 @@ func createPubSub(t *testing.T) (message.Publisher, message.Subscriber) {
}

func createPubSubWithConsumerGroup(t *testing.T, consumerGroup string) (message.Publisher, message.Subscriber) {
subClient, err := redisClient()
require.NoError(t, err)

return newPubSub(t, &SubscriberConfig{
Client: subClient,
Unmarshaller: &DefaultMarshallerUnmarshaller{},
Client: redisClientOrFail(t),
Consumer: watermill.NewShortUUID(),
ConsumerGroup: consumerGroup,
BlockTime: 10 * time.Millisecond,
Expand All @@ -86,15 +87,9 @@ func TestPublishSubscribe(t *testing.T) {
func TestSubscriber(t *testing.T) {
topic := watermill.NewShortUUID()

pubClient, err := redisClient()
require.NoError(t, err)
subClient, err := redisClient()
require.NoError(t, err)

subscriber, err := NewSubscriber(
SubscriberConfig{
Client: subClient,
Unmarshaller: &DefaultMarshallerUnmarshaller{},
Client: redisClientOrFail(t),
Consumer: watermill.NewShortUUID(),
ConsumerGroup: watermill.NewShortUUID(),
},
Expand All @@ -106,8 +101,7 @@ func TestSubscriber(t *testing.T) {

publisher, err := NewPublisher(
PublisherConfig{
Client: pubClient,
Marshaller: &DefaultMarshallerUnmarshaller{},
Client: redisClientOrFail(t),
},
watermill.NewStdLogger(false, false),
)
Expand Down Expand Up @@ -138,17 +132,9 @@ func TestSubscriber(t *testing.T) {
func TestFanOut(t *testing.T) {
topic := watermill.NewShortUUID()

fanOutPubClient, err := redisClient()
require.NoError(t, err)
fanOutSubClient1, err := redisClient()
require.NoError(t, err)
fanOutSubClient2, err := redisClient()
require.NoError(t, err)

subscriber1, err := NewSubscriber(
SubscriberConfig{
Client: fanOutSubClient1,
Unmarshaller: &DefaultMarshallerUnmarshaller{},
Client: redisClientOrFail(t),
Consumer: watermill.NewShortUUID(),
ConsumerGroup: "",
},
Expand All @@ -158,8 +144,7 @@ func TestFanOut(t *testing.T) {

subscriber2, err := NewSubscriber(
SubscriberConfig{
Client: fanOutSubClient2,
Unmarshaller: &DefaultMarshallerUnmarshaller{},
Client: redisClientOrFail(t),
Consumer: watermill.NewShortUUID(),
ConsumerGroup: "",
},
Expand All @@ -169,8 +154,7 @@ func TestFanOut(t *testing.T) {

publisher, err := NewPublisher(
PublisherConfig{
Client: fanOutPubClient,
Marshaller: &DefaultMarshallerUnmarshaller{},
Client: redisClientOrFail(t),
},
watermill.NewStdLogger(false, false),
)
Expand All @@ -196,7 +180,7 @@ func TestFanOut(t *testing.T) {
t.Fatal("msg nil")
}
t.Logf("subscriber 1: %v %v %v", msg.UUID, msg.Metadata, string(msg.Payload))
require.Equal(t, string(msg.Payload), ("test" + strconv.Itoa(i)))
require.Equal(t, string(msg.Payload), "test"+strconv.Itoa(i))
msg.Ack()
}
for i := 10; i < 50; i++ {
Expand All @@ -205,11 +189,157 @@ func TestFanOut(t *testing.T) {
t.Fatal("msg nil")
}
t.Logf("subscriber 2: %v %v %v", msg.UUID, msg.Metadata, string(msg.Payload))
require.Equal(t, string(msg.Payload), ("test" + strconv.Itoa(i)))
require.Equal(t, string(msg.Payload), "test"+strconv.Itoa(i))
msg.Ack()
}

require.NoError(t, publisher.Close())
require.NoError(t, subscriber1.Close())
require.NoError(t, subscriber2.Close())
}

func TestClaimIdle(t *testing.T) {
// should be long enough to be robust even for CI boxes
testInterval := 250 * time.Millisecond

topic := watermill.NewShortUUID()
consumerGroup := watermill.NewShortUUID()
testLogger := watermill.NewStdLogger(true, false)

router, err := message.NewRouter(message.RouterConfig{
CloseTimeout: testInterval,
}, testLogger)
require.NoError(t, err)

type messageWithMeta struct {
msgID int
subscriberID int
}

receivedCh := make(chan *messageWithMeta)

// let's start a few subscribers; each will wait between 3 and 5 intervals every time
// it receives a message
nSubscribers := 20
seen := make(map[string]map[string]bool)
var seenLock sync.Mutex
for subscriberID := 0; subscriberID < nSubscribers; subscriberID++ {
// need to assign to a variable local to the loop because of how golang
// handles loop variables in function literals
subID := subscriberID

suscriber, err := NewSubscriber(
SubscriberConfig{
Client: redisClientOrFail(t),
Consumer: strconv.Itoa(subID),
ConsumerGroup: consumerGroup,
ClaimInterval: testInterval,
MaxIdleTime: 2 * testInterval,
// we're only going to claim messages for consumers with odd IDs
ShouldClaimPendingMessage: func(ext redis.XPendingExt) bool {
idleConsumerID, err := strconv.Atoi(ext.Consumer)
require.NoError(t, err)

if idleConsumerID%2 == 0 {
return false
}

seenLock.Lock()
defer seenLock.Unlock()

if seen[ext.ID] == nil {
seen[ext.ID] = make(map[string]bool)
}
if seen[ext.ID][ext.Consumer] {
return false
}
seen[ext.ID][ext.Consumer] = true
return true
},
},
testLogger,
)
require.NoError(t, err)

router.AddNoPublisherHandler(
strconv.Itoa(subID),
topic,
suscriber,
func(msg *message.Message) error {
msgID, err := strconv.Atoi(string(msg.Payload))
require.NoError(t, err)

receivedCh <- &messageWithMeta{
msgID: msgID,
subscriberID: subID,
}
sleepInterval := (3 + 2*rand.Float64()) * float64(testInterval)
time.Sleep(time.Duration(sleepInterval))

return nil
},
)
}

runCtx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
require.NoError(t, router.Run(runCtx))
}()

// now let's push a few messages
publisher, err := NewPublisher(
PublisherConfig{
Client: redisClientOrFail(t),
},
testLogger,
)
require.NoError(t, err)

nMessages := 100
for msgID := 0; msgID < nMessages; msgID++ {
msg := message.NewMessage(watermill.NewShortUUID(), []byte(strconv.Itoa(msgID)))
require.NoError(t, publisher.Publish(topic, msg))
}

// now let's wait to receive them
receivedByID := make(map[int][]*messageWithMeta)
for len(receivedByID) != nMessages {
select {
case msg := <-receivedCh:
receivedByID[msg.msgID] = append(receivedByID[msg.msgID], msg)
case <-time.After(8 * testInterval):
t.Fatalf("timed out waiting for new messages, only received %d unique messages", len(receivedByID))
}
}

// shut down the router and the subscribers
cancel()
wg.Wait()

// now let's look at what we've received:
// * at least some messages should have been retried
// * for retried messages, there should be at most one consumer with an even ID
nMsgsWithRetries := 0
for _, withSameID := range receivedByID {
require.Greater(t, len(withSameID), 0)
if len(withSameID) == 1 {
// this message was not retried at all
continue
}

nMsgsWithRetries++

nEvenConsumers := 0
for _, msg := range withSameID {
if msg.subscriberID%2 == 0 {
nEvenConsumers++
}
}
assert.LessOrEqual(t, nEvenConsumers, 1)
}

assert.GreaterOrEqual(t, nMsgsWithRetries, 3)
}
Loading