diff --git a/pulsar/internal/semaphore.go b/pulsar/internal/semaphore.go index a34497f865..6a67cc3550 100644 --- a/pulsar/internal/semaphore.go +++ b/pulsar/internal/semaphore.go @@ -18,6 +18,7 @@ package internal import ( + "context" "sync/atomic" log "github.com/sirupsen/logrus" @@ -26,7 +27,7 @@ import ( type Semaphore interface { // Acquire a permit, if one is available and returns immediately, // reducing the number of available permits by one. - Acquire() + Acquire(ctx context.Context) bool // Try to acquire a permit. The method will return immediately // with a `true` if it was possible to acquire a permit and @@ -63,14 +64,21 @@ func NewSemaphore(maxPermits int32) Semaphore { } } -func (s *semaphore) Acquire() { +func (s *semaphore) Acquire(ctx context.Context) bool { permits := atomic.AddInt32(&s.permits, 1) if permits <= s.maxPermits { - return + return true } // Block on the channel until a new permit is available - <-s.ch + // or the context expires + select { + case <-s.ch: + return true + case <-ctx.Done(): + atomic.AddInt32(&s.permits, -1) + return false + } } func (s *semaphore) TryAcquire() bool { diff --git a/pulsar/internal/semaphore_test.go b/pulsar/internal/semaphore_test.go index 0de69fcf66..b692d6864f 100644 --- a/pulsar/internal/semaphore_test.go +++ b/pulsar/internal/semaphore_test.go @@ -18,6 +18,7 @@ package internal import ( + "context" "sync" "testing" "time" @@ -35,7 +36,7 @@ func TestSemaphore(t *testing.T) { for i := 0; i < n; i++ { go func() { - s.Acquire() + assert.True(t, s.Acquire(context.Background())) time.Sleep(100 * time.Millisecond) s.Release() wg.Done() @@ -48,7 +49,7 @@ func TestSemaphore(t *testing.T) { func TestSemaphore_TryAcquire(t *testing.T) { s := NewSemaphore(1) - s.Acquire() + assert.True(t, s.Acquire(context.Background())) assert.False(t, s.TryAcquire()) @@ -58,3 +59,18 @@ func TestSemaphore_TryAcquire(t *testing.T) { assert.False(t, s.TryAcquire()) s.Release() } + +func TestSemaphore_ContextExpire(t *testing.T) { + s := NewSemaphore(1) + + assert.True(t, s.Acquire(context.Background())) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + assert.False(t, s.Acquire(ctx)) + + assert.False(t, s.TryAcquire()) + s.Release() + + assert.True(t, s.TryAcquire()) +} diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go index 7dd176c1bf..002d261cfd 100644 --- a/pulsar/producer_partition.go +++ b/pulsar/producer_partition.go @@ -48,6 +48,7 @@ var ( errFailAddToBatch = newError(AddToBatchFailed, "message add to batch failed") errSendTimeout = newError(TimeoutError, "message send timeout") errSendQueueIsFull = newError(ProducerQueueIsFull, "producer send queue is full") + errContextExpired = newError(TimeoutError, "message send context expired") errMessageTooLarge = newError(MessageTooBig, "message size exceeds MaxMessageSize") buffersPool sync.Pool @@ -658,7 +659,10 @@ func (p *partitionProducer) internalSendAsync(ctx context.Context, msg *Producer return } } else { - p.publishSemaphore.Acquire() + if !p.publishSemaphore.Acquire(ctx) { + callback(nil, msg, errContextExpired) + return + } } p.metrics.MessagesPending.Inc() diff --git a/pulsar/producer_test.go b/pulsar/producer_test.go index 4d62cac169..7c3dbd76bb 100644 --- a/pulsar/producer_test.go +++ b/pulsar/producer_test.go @@ -930,6 +930,72 @@ func TestSendTimeout(t *testing.T) { makeHTTPCall(t, http.MethodDelete, quotaURL, "") } +func TestSendContextExpired(t *testing.T) { + quotaURL := adminURL + "/admin/v2/namespaces/public/default/backlogQuota" + quotaFmt := `{"limit": "%d", "policy": "producer_request_hold"}` + makeHTTPCall(t, http.MethodPost, quotaURL, fmt.Sprintf(quotaFmt, 1024)) + + client, err := NewClient(ClientOptions{ + URL: serviceURL, + }) + assert.NoError(t, err) + defer client.Close() + + topicName := newTopicName() + consumer, err := client.Subscribe(ConsumerOptions{ + Topic: topicName, + SubscriptionName: "send_context_expired_sub", + }) + assert.Nil(t, err) + defer consumer.Close() // subscribe but do nothing + + noRetry := uint(0) + producer, err := client.CreateProducer(ProducerOptions{ + Topic: topicName, + MaxPendingMessages: 1, + SendTimeout: 2 * time.Second, + MaxReconnectToBroker: &noRetry, + }) + assert.Nil(t, err) + defer producer.Close() + + // first send completes and fills the available backlog + id, err := producer.Send(context.Background(), &ProducerMessage{ + Payload: make([]byte, 1024), + }) + assert.Nil(t, err) + assert.NotNil(t, id) + + // waiting for the backlog check + time.Sleep((5 + 1) * time.Second) + + // next publish will not complete due to the backlog quota being full; + // this consumes the only available MaxPendingMessages permit + wg := sync.WaitGroup{} + wg.Add(1) + producer.SendAsync(context.Background(), &ProducerMessage{ + Payload: make([]byte, 1024), + }, func(_ MessageID, _ *ProducerMessage, _ error) { + // we're not interested in the result of this send, but we don't + // want to exit this test case until it completes + wg.Done() + }) + + // final publish will block waiting for a send permit to become available + // then fail when the ctx times out + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + id, err = producer.Send(ctx, &ProducerMessage{ + Payload: make([]byte, 1024), + }) + assert.NotNil(t, err) + assert.Nil(t, id) + + wg.Wait() + + makeHTTPCall(t, http.MethodDelete, quotaURL, "") +} + type noopProduceInterceptor struct{} func (noopProduceInterceptor) BeforeSend(producer Producer, message *ProducerMessage) {}