Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,35 @@ func New(client QueueAPI, config *Config) *Worker {

// Start starts the polling and will continue polling till the application is forcibly stopped
func (worker *Worker) Start(ctx context.Context, h Handler) {
messages := make(chan *sqs.Message, worker.Config.MaxNumberOfMessage)
var wg sync.WaitGroup

go worker.startPolling(ctx, messages, &wg)

for {
message, ok := <-messages
if !ok {
break
}

go func(m *sqs.Message) {
// launch goroutine
defer wg.Done()
if err := worker.handleMessage(m, h); err != nil {
worker.Log.Error(err.Error())
}
}(message)
}

wg.Wait()
}

func (worker *Worker) startPolling(ctx context.Context, messages chan *sqs.Message, wg *sync.WaitGroup) {
for {
select {
case <-ctx.Done():
log.Println("worker: Stopping polling because a context kill signal was sent")
close(messages)
return
default:
worker.Log.Debug("worker: Start Polling")
Expand All @@ -104,32 +129,17 @@ func (worker *Worker) Start(ctx context.Context, h Handler) {
continue
}
if len(resp.Messages) > 0 {
worker.run(h, resp.Messages)
numMessages := len(resp.Messages)
worker.Log.Info(fmt.Sprintf("worker: Received %d messages", numMessages))
wg.Add(numMessages)
for _, message := range resp.Messages {
messages <- message
}
}
}
}
}

// poll launches goroutine per received message and wait for all message to be processed
func (worker *Worker) run(h Handler, messages []*sqs.Message) {
numMessages := len(messages)
worker.Log.Info(fmt.Sprintf("worker: Received %d messages", numMessages))

var wg sync.WaitGroup
wg.Add(numMessages)
for i := range messages {
go func(m *sqs.Message) {
// launch goroutine
defer wg.Done()
if err := worker.handleMessage(m, h); err != nil {
worker.Log.Error(err.Error())
}
}(messages[i])
}

wg.Wait()
}

func (worker *Worker) handleMessage(m *sqs.Message, h Handler) error {
var err error
err = h.HandleMessage(m)
Expand Down
67 changes: 41 additions & 26 deletions worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"strconv"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sqs"
Expand All @@ -14,8 +15,10 @@ import (
)

type mockedSqsClient struct {
Config *aws.Config
Response sqs.ReceiveMessageOutput
Config *aws.Config
Messages []*sqs.Message
ReceiveIndex int
Cancel context.CancelFunc
QueueAPI
mock.Mock
}
Expand All @@ -29,33 +32,46 @@ func (c *mockedSqsClient) GetQueueUrl(urlInput *sqs.GetQueueUrlInput) (*sqs.GetQ
func (c *mockedSqsClient) ReceiveMessage(input *sqs.ReceiveMessageInput) (*sqs.ReceiveMessageOutput, error) {
c.Called(input)

return &c.Response, nil
startRange := c.ReceiveIndex
endRange := startRange + rand.Intn(9)
if endRange > totalNumberOfMessages {
endRange = totalNumberOfMessages
c.Cancel()
}

messages := c.Messages[startRange:endRange]
c.ReceiveIndex += endRange - startRange

return &sqs.ReceiveMessageOutput{Messages: messages}, nil
}

func (c *mockedSqsClient) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) {
c.Called(input)
c.Response = sqs.ReceiveMessageOutput{}

return &sqs.DeleteMessageOutput{}, nil
}

type mockedHandler struct {
mock.Mock
}

func (mh *mockedHandler) HandleMessage(foo string, qux string) {
mh.Called(foo, qux)
func (mh *mockedHandler) HandleMessage(receiptHandle, foo, qux string) {
mh.Called(receiptHandle, foo, qux)
}

type sqsEvent struct {
Foo string `json:"foo"`
Qux string `json:"qux"`
}

const maxNumberOfMessages = 1984
const waitTimeSecond = 1337
const maxNumberOfMessages = 9
const waitTimeSecond = 19

const totalNumberOfMessages = 100

func TestStart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

region := "eu-west-1"
awsConfig := &aws.Config{Region: &region}
workerConfig := &Config{
Expand All @@ -65,23 +81,24 @@ func TestStart(t *testing.T) {
}

clientParams := buildClientParams()
sqsMessage := &sqs.Message{Body: aws.String(`{ "foo": "bar", "qux": "baz" }`)}
sqsResponse := sqs.ReceiveMessageOutput{Messages: []*sqs.Message{sqsMessage}}
client := &mockedSqsClient{Response: sqsResponse, Config: awsConfig}
deleteInput := &sqs.DeleteMessageInput{QueueUrl: clientParams.QueueUrl}
sqsMessages := make([]*sqs.Message, totalNumberOfMessages)
for i := 0; i < totalNumberOfMessages; i++ {
sqsMessages[i] = &sqs.Message{
Body: aws.String(`{ "foo": "bar", "qux": "baz" }`),
ReceiptHandle: aws.String(strconv.Itoa(i)),
}
}
client := &mockedSqsClient{Messages: sqsMessages, Config: awsConfig, Cancel: cancel}

worker := New(client, workerConfig)

ctx, cancel := contextAndCancel()
defer cancel()

handler := new(mockedHandler)
handlerFunc := HandlerFunc(func(msg *sqs.Message) (err error) {
event := &sqsEvent{}

json.Unmarshal([]byte(aws.StringValue(msg.Body)), event)

handler.HandleMessage(event.Foo, event.Qux)
handler.HandleMessage(*msg.ReceiptHandle, event.Foo, event.Qux)

return
})
Expand All @@ -107,8 +124,12 @@ func TestStart(t *testing.T) {

t.Run("the worker successfully processes a message", func(t *testing.T) {
client.On("ReceiveMessage", clientParams).Return()
client.On("DeleteMessage", deleteInput).Return()
handler.On("HandleMessage", "bar", "baz").Return().Once()
for i := 0; i < totalNumberOfMessages; i++ {
receiptHandle := strconv.Itoa(i)
client.On("DeleteMessage",
&sqs.DeleteMessageInput{QueueUrl: clientParams.QueueUrl, ReceiptHandle: &receiptHandle}).Return().Once()
handler.On("HandleMessage", receiptHandle, "bar", "baz").Return().Once()
}

worker.Start(ctx, handlerFunc)

Expand All @@ -117,12 +138,6 @@ func TestStart(t *testing.T) {
})
}

func contextAndCancel() (context.Context, context.CancelFunc) {
delay := time.Now().Add(1 * time.Millisecond)

return context.WithDeadline(context.Background(), delay)
}

func buildClientParams() *sqs.ReceiveMessageInput {
url := aws.String("https://sqs.eu-west-1.amazonaws.com/123456789/my-sqs-queue")

Expand Down