Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 135 additions & 81 deletions shard_consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package kinsumer

import (
"context"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -25,6 +26,13 @@ const (
errorSleepDuration = 1 * time.Second
)

type consumeEvent struct {
Records []*kinesis.Record
Lag time.Duration
SequenceNumber string
Finished bool
}

// getShardIterator gets a shard iterator after the last sequence number we read or at the start of the stream
func getShardIterator(k kinesisiface.KinesisAPI, streamName string, shardID string, sequenceNumber string) (string, error) {
shardIteratorType := kinesis.ShardIteratorTypeAfterSequenceNumber
Expand Down Expand Up @@ -105,6 +113,9 @@ func (k *Kinsumer) captureShard(shardID string) (*checkpointer, error) {
// TODO: There are no tests for this file. Not sure how to even unit test this.
func (k *Kinsumer) consume(shardID string) {
defer k.waitGroup.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stop := k.stop

// commitTicker is used to periodically commit, so that we don't hammer dynamo every time
// a shard wants to be check pointed
Expand All @@ -124,10 +135,6 @@ func (k *Kinsumer) consume(shardID string) {
return
}

sequenceNumber := checkpointer.sequenceNumber

// finished means we have reached the end of the shard but haven't necessarily processed/committed everything
finished := false
// Make sure we release the shard when we are done.
defer func() {
innerErr := checkpointer.release()
Expand All @@ -137,32 +144,16 @@ func (k *Kinsumer) consume(shardID string) {
}
}()

// Get the starting shard iterator
iterator, err := getShardIterator(k.kinesis, k.streamName, shardID, sequenceNumber)
if err != nil {
k.shardErrors <- shardConsumerError{shardID: shardID, action: "getShardIterator", err: err}
return
}

// no throttle on the first request.
nextThrottle := time.After(0)

retryCount := 0

var lastSeqNum string
mainloop:
for {
// We have reached the end of the shard's data. Set Finished in dynamo and stop processing.
if iterator == "" && !finished {
checkpointer.finish(lastSeqNum)
finished = true
}
commitDone := make(chan struct{})
go func() {
defer close(commitDone)
for {
select {
case <-ctx.Done():
return
case <-commitTicker.C:
}

// Handle async actions, and throttle requests to keep kinesis happy
select {
case <-k.stop:
return
case <-commitTicker.C:
finishCommitted, err := checkpointer.commit()
if err != nil {
k.shardErrors <- shardConsumerError{shardID: shardID, action: "checkpointer.commit", err: err}
Expand All @@ -171,71 +162,134 @@ mainloop:
if finishCommitted {
return
}
// Go back to waiting for a throttle/stop.
continue mainloop
case <-nextThrottle:
}
}()

// Reset the nextThrottle
nextThrottle = time.After(k.config.throttleDelay)
sequenceNumber := checkpointer.sequenceNumber

if finished {
continue mainloop
}
evtCh := k.consumePolling(ctx, shardID, sequenceNumber)

// Get records from kinesis
records, next, lag, err := getRecords(k.kinesis, iterator)
mainloop:
// Continue processing until both the checkpointer and event goroutines are done
// Signal to those goroutines to stop when either of them stop, or when a stop request comes in
for commitDone != nil || evtCh != nil {
select {
case <-stop:
cancel()
stop = nil
case <-commitDone:
cancel()
commitDone = nil
case e, ok := <-evtCh:
if !ok {
cancel()
evtCh = nil
continue mainloop
}
if e.Finished {
checkpointer.finish(e.SequenceNumber)
}

if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
k.config.logger.Log("Got error: %s (%s) retry count is %d / %d", awsErr.Message(), awsErr.OrigErr(), retryCount, maxErrorRetries)
if retryCount < maxErrorRetries {
retryCount++

// casting retryCount here to time.Duration purely for the multiplication, there is
// no meaning to retryCount nanoseconds
time.Sleep(errorSleepDuration * time.Duration(retryCount))
continue mainloop
// Put all the records we got onto the channel
k.config.stats.EventsFromKinesis(len(e.Records), shardID, e.Lag)
retrievedAt := time.Now()
for _, record := range e.Records {
r := &consumedRecord{
record: record,
checkpointer: checkpointer,
retrievedAt: retrievedAt,
}
// Do a blocking send, unless we're stopping in which case do a non-blocking send
// We have to check both ctx.Done and stop since we don't know if stop has been processed yet
select {
case <-ctx.Done():
case <-stop:
case k.records <- r:
}
}
k.shardErrors <- shardConsumerError{shardID: shardID, action: "getRecords", err: err}
return
}
retryCount = 0
}
}

// Put all the records we got onto the channel
k.config.stats.EventsFromKinesis(len(records), shardID, lag)
if len(records) > 0 {
retrievedAt := time.Now()
for _, record := range records {
RecordLoop:
// Loop until we stop or the record is consumed, checkpointing if necessary.
for {
select {
case <-commitTicker.C:
finishCommitted, err := checkpointer.commit()
if err != nil {
k.shardErrors <- shardConsumerError{shardID: shardID, action: "checkpointer.commit", err: err}
return
func (k *Kinsumer) consumePolling(ctx context.Context, shardID string, sequenceNumber string) <-chan consumeEvent {
ch := make(chan consumeEvent)

go func() {
defer close(ch)

// no throttle on the first request.
nextThrottle := time.After(0)
retryCount := 0

for {
// Get the starting shard iterator
iterator, err := getShardIterator(k.kinesis, k.streamName, shardID, sequenceNumber)
if err != nil {
k.shardErrors <- shardConsumerError{shardID: shardID, action: "getShardIterator", err: err}
return
}

getrecordloop:
for {
select {
case <-ctx.Done():
return
case <-nextThrottle:
}

// Reset the nextThrottle
nextThrottle = time.After(k.config.throttleDelay)

// Get records from kinesis
records, next, lag, err := getRecords(k.kinesis, iterator)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Iterators expire after 5 minutes, which can happen if Next() is called too slowly
// Rather than error we can generate a new iterator and recover
if awsErr.Code() == "ExpiredIteratorException" {
break getrecordloop
}
if finishCommitted {
return

k.config.logger.Log("Got error: %s (%s) retry count is %d / %d", awsErr.Message(), awsErr.OrigErr(), retryCount, maxErrorRetries)
if retryCount < maxErrorRetries {
retryCount++

// casting retryCount here to time.Duration purely for the multiplication, there is
// no meaning to retryCount nanoseconds
time.Sleep(errorSleepDuration * time.Duration(retryCount))
continue getrecordloop
}
case <-k.stop:
return
case k.records <- &consumedRecord{
record: record,
checkpointer: checkpointer,
retrievedAt: retrievedAt,
}:
break RecordLoop
}
k.shardErrors <- shardConsumerError{shardID: shardID, action: "getRecords", err: err}
return
}
}

// Update the last sequence number we saw, in case we reached the end of the stream.
lastSeqNum = aws.StringValue(records[len(records)-1].SequenceNumber)
retryCount = 0
iterator = next
if len(records) > 0 {
sequenceNumber = aws.StringValue(records[len(records)-1].SequenceNumber)
}
finished := iterator == ""

ev := consumeEvent{
Records: records,
Lag: lag,
SequenceNumber: sequenceNumber,
Finished: finished,
}

select {
case <-ctx.Done():
return
case ch <- ev:
}

if finished {
return
}
}
}
iterator = next
}
}()

return ch
}