Skip to content

Commit 21747c9

Browse files
Update RateLimiter queues on cancellation (#64825)
1 parent acbd20e commit 21747c9

File tree

5 files changed

+121
-12
lines changed

5 files changed

+121
-12
lines changed

src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/ConcurrencyLimiter.cs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,13 @@ protected override ValueTask<RateLimitLease> WaitAsyncCore(int permitCount, Canc
118118
}
119119
}
120120

121-
TaskCompletionSource<RateLimitLease> tcs = new TaskCompletionSource<RateLimitLease>(TaskCreationOptions.RunContinuationsAsynchronously);
121+
CancelQueueState tcs = new CancelQueueState(permitCount, this, cancellationToken);
122122
CancellationTokenRegistration ctr = default;
123123
if (cancellationToken.CanBeCanceled)
124124
{
125125
ctr = cancellationToken.Register(static obj =>
126126
{
127-
((TaskCompletionSource<RateLimitLease>)obj!).TrySetException(new OperationCanceledException());
127+
((CancelQueueState)obj!).TrySetCanceled();
128128
}, tcs);
129129
}
130130

@@ -194,7 +194,6 @@ private void Release(int releaseCount)
194194

195195
_permitCount -= nextPendingRequest.Count;
196196
_queueCount -= nextPendingRequest.Count;
197-
Debug.Assert(_queueCount >= 0);
198197
Debug.Assert(_permitCount >= 0);
199198

200199
ConcurrencyLease lease = nextPendingRequest.Count == 0 ? SuccessfulLease : new ConcurrencyLease(true, this, nextPendingRequest.Count);
@@ -203,8 +202,11 @@ private void Release(int releaseCount)
203202
{
204203
// Queued item was canceled so add count back
205204
_permitCount += nextPendingRequest.Count;
205+
// Updating queue count is handled by the cancellation code
206+
_queueCount += nextPendingRequest.Count;
206207
}
207208
nextPendingRequest.CancellationTokenRegistration.Dispose();
209+
Debug.Assert(_queueCount >= 0);
208210
}
209211
else
210212
{
@@ -319,5 +321,33 @@ public RequestRegistration(int requestedCount, TaskCompletionSource<RateLimitLea
319321

320322
public CancellationTokenRegistration CancellationTokenRegistration { get; }
321323
}
324+
325+
private sealed class CancelQueueState : TaskCompletionSource<RateLimitLease>
326+
{
327+
private readonly int _permitCount;
328+
private readonly ConcurrencyLimiter _limiter;
329+
private readonly CancellationToken _cancellationToken;
330+
331+
public CancelQueueState(int permitCount, ConcurrencyLimiter limiter, CancellationToken cancellationToken)
332+
: base(TaskCreationOptions.RunContinuationsAsynchronously)
333+
{
334+
_permitCount = permitCount;
335+
_limiter = limiter;
336+
_cancellationToken = cancellationToken;
337+
}
338+
339+
public new bool TrySetCanceled()
340+
{
341+
if (TrySetCanceled(_cancellationToken))
342+
{
343+
lock (_limiter.Lock)
344+
{
345+
_limiter._queueCount -= _permitCount;
346+
}
347+
return true;
348+
}
349+
return false;
350+
}
351+
}
322352
}
323353
}

src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/TokenBucketRateLimiter.cs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,13 @@ protected override ValueTask<RateLimitLease> WaitAsyncCore(int tokenCount, Cance
124124
}
125125
}
126126

127-
TaskCompletionSource<RateLimitLease> tcs = new TaskCompletionSource<RateLimitLease>(TaskCreationOptions.RunContinuationsAsynchronously);
128-
127+
CancelQueueState tcs = new CancelQueueState(tokenCount, this, cancellationToken);
129128
CancellationTokenRegistration ctr = default;
130129
if (cancellationToken.CanBeCanceled)
131130
{
132131
ctr = cancellationToken.Register(static obj =>
133132
{
134-
((TaskCompletionSource<RateLimitLease>)obj!).TrySetException(new OperationCanceledException());
133+
((CancelQueueState)obj!).TrySetCanceled();
135134
}, tcs);
136135
}
137136

@@ -140,7 +139,6 @@ protected override ValueTask<RateLimitLease> WaitAsyncCore(int tokenCount, Cance
140139
_queueCount += tokenCount;
141140
Debug.Assert(_queueCount <= _options.QueueLimit);
142141

143-
// handle cancellation
144142
return new ValueTask<RateLimitLease>(registration.Tcs.Task);
145143
}
146144
}
@@ -276,15 +274,17 @@ private void ReplenishInternal(uint nowTicks)
276274

277275
_queueCount -= nextPendingRequest.Count;
278276
_tokenCount -= nextPendingRequest.Count;
279-
Debug.Assert(_queueCount >= 0);
280277
Debug.Assert(_tokenCount >= 0);
281278

282279
if (!nextPendingRequest.Tcs.TrySetResult(SuccessfulLease))
283280
{
284281
// Queued item was canceled so add count back
285282
_tokenCount += nextPendingRequest.Count;
283+
// Updating queue count is handled by the cancellation code
284+
_queueCount += nextPendingRequest.Count;
286285
}
287286
nextPendingRequest.CancellationTokenRegistration.Dispose();
287+
Debug.Assert(_queueCount >= 0);
288288
}
289289
else
290290
{
@@ -380,7 +380,34 @@ public RequestRegistration(int tokenCount, TaskCompletionSource<RateLimitLease>
380380
public TaskCompletionSource<RateLimitLease> Tcs { get; }
381381

382382
public CancellationTokenRegistration CancellationTokenRegistration { get; }
383+
}
384+
385+
private sealed class CancelQueueState : TaskCompletionSource<RateLimitLease>
386+
{
387+
private readonly int _tokenCount;
388+
private readonly TokenBucketRateLimiter _limiter;
389+
private readonly CancellationToken _cancellationToken;
390+
391+
public CancelQueueState(int tokenCount, TokenBucketRateLimiter limiter, CancellationToken cancellationToken)
392+
: base(TaskCreationOptions.RunContinuationsAsynchronously)
393+
{
394+
_tokenCount = tokenCount;
395+
_limiter = limiter;
396+
_cancellationToken = cancellationToken;
397+
}
383398

399+
public new bool TrySetCanceled()
400+
{
401+
if (TrySetCanceled(_cancellationToken))
402+
{
403+
lock (_limiter.Lock)
404+
{
405+
_limiter._queueCount -= _tokenCount;
406+
}
407+
return true;
408+
}
409+
return false;
410+
}
384411
}
385412
}
386413
}

src/libraries/System.Threading.RateLimiting/tests/BaseRateLimiterTests.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ public abstract class BaseRateLimiterTests
8080
[Fact]
8181
public abstract Task CanCancelWaitAsyncBeforeQueuing();
8282

83+
[Fact]
84+
public abstract Task CancelUpdatesQueueLimit();
85+
8386
[Fact]
8487
public abstract Task CanAcquireResourcesWithAcquireWithQueuedItemsIfNewestFirst();
8588

src/libraries/System.Threading.RateLimiting/tests/ConcurrencyLimiterTests.cs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,8 @@ public override async Task CanCancelWaitAsyncAfterQueuing()
401401
var wait = limiter.WaitAsync(1, cts.Token);
402402

403403
cts.Cancel();
404-
await Assert.ThrowsAsync<OperationCanceledException>(() => wait.AsTask());
404+
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => wait.AsTask());
405+
Assert.Equal(cts.Token, ex.CancellationToken);
405406

406407
lease.Dispose();
407408

@@ -418,13 +419,36 @@ public override async Task CanCancelWaitAsyncBeforeQueuing()
418419
var cts = new CancellationTokenSource();
419420
cts.Cancel();
420421

421-
await Assert.ThrowsAsync<TaskCanceledException>(() => limiter.WaitAsync(1, cts.Token).AsTask());
422+
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => limiter.WaitAsync(1, cts.Token).AsTask());
423+
Assert.Equal(cts.Token, ex.CancellationToken);
422424

423425
lease.Dispose();
424426

425427
Assert.Equal(1, limiter.GetAvailablePermits());
426428
}
427429

430+
[Fact]
431+
public override async Task CancelUpdatesQueueLimit()
432+
{
433+
var limiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions(1, QueueProcessingOrder.OldestFirst, 1));
434+
var lease = limiter.Acquire(1);
435+
Assert.True(lease.IsAcquired);
436+
437+
var cts = new CancellationTokenSource();
438+
var wait = limiter.WaitAsync(1, cts.Token);
439+
440+
cts.Cancel();
441+
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => wait.AsTask());
442+
Assert.Equal(cts.Token, ex.CancellationToken);
443+
444+
wait = limiter.WaitAsync(1);
445+
Assert.False(wait.IsCompleted);
446+
447+
lease.Dispose();
448+
lease = await wait;
449+
Assert.True(lease.IsAcquired);
450+
}
451+
428452
[Fact]
429453
public override void NoMetadataOnAcquiredLease()
430454
{

src/libraries/System.Threading.RateLimiting/tests/TokenBucketRateLimiterTests.cs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ public override async Task CanCancelWaitAsyncAfterQueuing()
354354
var wait = limiter.WaitAsync(1, cts.Token);
355355

356356
cts.Cancel();
357-
await Assert.ThrowsAsync<OperationCanceledException>(() => wait.AsTask());
357+
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => wait.AsTask());
358+
Assert.Equal(cts.Token, ex.CancellationToken);
358359

359360
lease.Dispose();
360361
Assert.True(limiter.TryReplenish());
@@ -373,14 +374,38 @@ public override async Task CanCancelWaitAsyncBeforeQueuing()
373374
var cts = new CancellationTokenSource();
374375
cts.Cancel();
375376

376-
await Assert.ThrowsAsync<TaskCanceledException>(() => limiter.WaitAsync(1, cts.Token).AsTask());
377+
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => limiter.WaitAsync(1, cts.Token).AsTask());
378+
Assert.Equal(cts.Token, ex.CancellationToken);
377379

378380
lease.Dispose();
379381
Assert.True(limiter.TryReplenish());
380382

381383
Assert.Equal(1, limiter.GetAvailablePermits());
382384
}
383385

386+
[Fact]
387+
public override async Task CancelUpdatesQueueLimit()
388+
{
389+
var limiter = new TokenBucketRateLimiter(new TokenBucketRateLimiterOptions(1, QueueProcessingOrder.OldestFirst, 1,
390+
TimeSpan.Zero, 1, autoReplenishment: false));
391+
var lease = limiter.Acquire(1);
392+
Assert.True(lease.IsAcquired);
393+
394+
var cts = new CancellationTokenSource();
395+
var wait = limiter.WaitAsync(1, cts.Token);
396+
397+
cts.Cancel();
398+
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => wait.AsTask());
399+
Assert.Equal(cts.Token, ex.CancellationToken);
400+
401+
wait = limiter.WaitAsync(1);
402+
Assert.False(wait.IsCompleted);
403+
404+
limiter.TryReplenish();
405+
lease = await wait;
406+
Assert.True(lease.IsAcquired);
407+
}
408+
384409
[Fact]
385410
public override void NoMetadataOnAcquiredLease()
386411
{

0 commit comments

Comments
 (0)