Skip to content

Commit b4e0169

Browse files
authored
Dedup UnboundedChannel and UnboundedPriorityChannel (dotnet#101396)
* Dedup UnboundedChannel and UnboundedPriorityChannel We can use generic specialization to avoid duplicating all the code for the different queue types. This should also make it much simpler to add other queue types in the future. * Address PR feedback
1 parent 56dcfd7 commit b4e0169

File tree

7 files changed

+179
-402
lines changed

7 files changed

+179
-402
lines changed

src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>
2525
<Compile Include="System\Threading\Channels\Channel_1.cs" />
2626
<Compile Include="System\Threading\Channels\Channel_2.cs" />
2727
<Compile Include="System\Threading\Channels\IDebugEnumerator.cs" />
28+
<Compile Include="System\Threading\Channels\IUnboundedChannelQueue.cs" />
2829
<Compile Include="System\Threading\Channels\SingleConsumerUnboundedChannel.cs" />
2930
<Compile Include="System\Threading\Channels\UnboundedChannel.cs" />
3031
<Compile Include="$(CommonPath)Internal\Padding.cs" Link="Common\Internal\Padding.cs" />
@@ -44,7 +45,6 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>
4445
<Compile Include="System\Threading\Channels\AsyncOperation.netcoreapp.cs" />
4546
<Compile Include="System\Threading\Channels\Channel.netcoreapp.cs" />
4647
<Compile Include="System\Threading\Channels\ChannelOptions.netcoreapp.cs" />
47-
<Compile Include="System\Threading\Channels\UnboundedPriorityChannel.cs" />
4848
</ItemGroup>
4949

5050
<ItemGroup Condition="'$(TargetFramework)' == '$(NetCoreAppCurrent)'">

src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Collections.Concurrent;
5+
using System.Collections.Generic;
6+
using System.Diagnostics.CodeAnalysis;
7+
48
namespace System.Threading.Channels
59
{
610
/// <summary>Provides static methods for creating channels.</summary>
@@ -9,7 +13,7 @@ public static partial class Channel
913
/// <summary>Creates an unbounded channel usable by any number of readers and writers concurrently.</summary>
1014
/// <returns>The created channel.</returns>
1115
public static Channel<T> CreateUnbounded<T>() =>
12-
new UnboundedChannel<T>(runContinuationsAsynchronously: true);
16+
new UnboundedChannel<T, UnboundedChannelConcurrentQueue<T>>(new(new()), runContinuationsAsynchronously: true);
1317

1418
/// <summary>Creates an unbounded channel subject to the provided options.</summary>
1519
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
@@ -27,7 +31,7 @@ public static Channel<T> CreateUnbounded<T>(UnboundedChannelOptions options)
2731
return new SingleConsumerUnboundedChannel<T>(!options.AllowSynchronousContinuations);
2832
}
2933

30-
return new UnboundedChannel<T>(!options.AllowSynchronousContinuations);
34+
return new UnboundedChannel<T, UnboundedChannelConcurrentQueue<T>>(new(new()), !options.AllowSynchronousContinuations);
3135
}
3236

3337
/// <summary>Creates a channel with the specified maximum capacity.</summary>
@@ -71,5 +75,32 @@ public static Channel<T> CreateBounded<T>(BoundedChannelOptions options, Action<
7175

7276
return new BoundedChannel<T>(options.Capacity, options.FullMode, !options.AllowSynchronousContinuations, itemDropped);
7377
}
78+
79+
/// <summary>Provides an <see cref="IUnboundedChannelQueue{T}"/> for a <see cref="ConcurrentQueue{T}"/>.</summary>
80+
private readonly struct UnboundedChannelConcurrentQueue<T>(ConcurrentQueue<T> queue) : IUnboundedChannelQueue<T>
81+
{
82+
private readonly ConcurrentQueue<T> _queue = queue;
83+
84+
/// <inheritdoc/>
85+
public bool IsThreadSafe => true;
86+
87+
/// <inheritdoc/>
88+
public void Enqueue(T item) => _queue.Enqueue(item);
89+
90+
/// <inheritdoc/>
91+
public bool TryDequeue([MaybeNullWhen(false)] out T item) => _queue.TryDequeue(out item);
92+
93+
/// <inheritdoc/>
94+
public bool TryPeek([MaybeNullWhen(false)] out T item) => _queue.TryPeek(out item);
95+
96+
/// <inheritdoc/>
97+
public int Count => _queue.Count;
98+
99+
/// <inheritdoc/>
100+
public bool IsEmpty => _queue.IsEmpty;
101+
102+
/// <inheritdoc/>
103+
public IEnumerator<T> GetEnumerator() => _queue.GetEnumerator();
104+
}
74105
}
75106
}

src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.netcoreapp.cs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,22 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Collections.Generic;
5+
using System.Diagnostics.CodeAnalysis;
56

67
namespace System.Threading.Channels
78
{
89
/// <summary>Provides static methods for creating channels.</summary>
910
public static partial class Channel
1011
{
1112
/// <summary>Creates an unbounded prioritized channel usable by any number of readers and writers concurrently.</summary>
13+
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
1214
/// <returns>The created channel.</returns>
1315
/// <remarks>
1416
/// <see cref="Comparer{T}.Default"/> is used to determine priority of elements.
1517
/// The next item read from the channel will be the element available in the channel with the lowest priority value.
1618
/// </remarks>
1719
public static Channel<T> CreateUnboundedPrioritized<T>() =>
18-
new UnboundedPrioritizedChannel<T>(runContinuationsAsynchronously: true, comparer: null);
20+
new UnboundedChannel<T, UnboundedChannelPriorityQueue<T>>(new(new()), runContinuationsAsynchronously: true);
1921

2022
/// <summary>Creates an unbounded prioritized channel subject to the provided options.</summary>
2123
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
@@ -30,7 +32,45 @@ public static Channel<T> CreateUnboundedPrioritized<T>(UnboundedPrioritizedChann
3032
{
3133
ArgumentNullException.ThrowIfNull(options);
3234

33-
return new UnboundedPrioritizedChannel<T>(!options.AllowSynchronousContinuations, options.Comparer);
35+
return new UnboundedChannel<T, UnboundedChannelPriorityQueue<T>>(new(new(options.Comparer)), !options.AllowSynchronousContinuations);
36+
}
37+
38+
/// <summary>Provides an <see cref="IUnboundedChannelQueue{T}"/> for a <see cref="PriorityQueue{TElement, TPriority}"/>.</summary>
39+
private readonly struct UnboundedChannelPriorityQueue<T>(PriorityQueue<bool, T> queue) : IUnboundedChannelQueue<T>
40+
{
41+
private readonly PriorityQueue<bool, T> _queue = queue;
42+
43+
/// <inheritdoc/>
44+
public bool IsThreadSafe => false;
45+
46+
/// <inheritdoc/>
47+
public void Enqueue(T item) => _queue.Enqueue(true, item);
48+
49+
/// <inheritdoc/>
50+
public bool TryDequeue([MaybeNullWhen(false)] out T item) => _queue.TryDequeue(out _, out item);
51+
52+
/// <inheritdoc/>
53+
public bool TryPeek([MaybeNullWhen(false)] out T item) => _queue.TryPeek(out _, out item);
54+
55+
/// <inheritdoc/>
56+
public int Count => _queue.Count;
57+
58+
/// <inheritdoc/>
59+
public bool IsEmpty => _queue.Count == 0;
60+
61+
/// <inheritdoc/>
62+
public IEnumerator<T> GetEnumerator()
63+
{
64+
List<T> list = [];
65+
foreach ((bool _, T Priority) item in _queue.UnorderedItems)
66+
{
67+
list.Add(item.Priority);
68+
}
69+
70+
list.Sort(_queue.Comparer);
71+
72+
return list.GetEnumerator();
73+
}
3474
}
3575
}
3676
}

src/libraries/System.Threading.Channels/src/System/Threading/Channels/IDebugEnumerator.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ internal interface IDebugEnumerable<T>
1111
IEnumerator<T> GetEnumerator();
1212
}
1313

14-
internal sealed class DebugEnumeratorDebugView<T>
14+
internal class DebugEnumeratorDebugView<T>
1515
{
1616
public DebugEnumeratorDebugView(IDebugEnumerable<T> enumerable)
1717
{
@@ -26,4 +26,6 @@ public DebugEnumeratorDebugView(IDebugEnumerable<T> enumerable)
2626
[DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
2727
public T[] Items { get; }
2828
}
29+
30+
internal sealed class DebugEnumeratorDebugView<T, TOther>(IDebugEnumerable<T> enumerable) : DebugEnumeratorDebugView<T>(enumerable);
2931
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Collections.Generic;
5+
using System.Diagnostics.CodeAnalysis;
6+
7+
namespace System.Threading.Channels
8+
{
9+
/// <summary>Representation of the queue data structure used by <see cref="UnboundedChannel{T, TQueue}"/>.</summary>
10+
internal interface IUnboundedChannelQueue<T> : IDebugEnumerable<T>
11+
{
12+
/// <summary>Gets whether the other members are safe to use concurrently with each other and themselves.</summary>
13+
bool IsThreadSafe { get; }
14+
15+
/// <summary>Enqueues an item into the queue.</summary>
16+
/// <param name="item">The item to enqueue.</param>
17+
void Enqueue(T item);
18+
19+
/// <summary>Dequeues an item from the queue, if possible.</summary>
20+
/// <param name="item">The dequeued item, or default if the queue was empty.</param>
21+
/// <returns>Whether an item was dequeued.</returns>
22+
bool TryDequeue([MaybeNullWhen(false)] out T item);
23+
24+
/// <summary>Peeks at the next item from the queue that would be dequeued, if possible.</summary>
25+
/// <param name="item">The peeked item, or default if the queue was empty.</param>
26+
/// <returns>Whether an item was peeked.</returns>
27+
bool TryPeek([MaybeNullWhen(false)] out T item);
28+
29+
/// <summary>Gets the number of elements in the queue.</summary>
30+
int Count { get; }
31+
32+
/// <summary>Gets whether the queue is empty.</summary>
33+
bool IsEmpty { get; }
34+
}
35+
}

src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,20 @@
55
using System.Collections.Generic;
66
using System.Diagnostics;
77
using System.Diagnostics.CodeAnalysis;
8+
using System.Runtime.CompilerServices;
89
using System.Threading.Tasks;
910

1011
namespace System.Threading.Channels
1112
{
1213
/// <summary>Provides a buffered channel of unbounded capacity.</summary>
1314
[DebuggerDisplay("Items = {ItemsCountForDebugger}, Closed = {ChannelIsClosedForDebugger}")]
14-
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
15-
internal sealed class UnboundedChannel<T> : Channel<T>, IDebugEnumerable<T>
15+
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
16+
internal sealed class UnboundedChannel<T, TQueue> : Channel<T>, IDebugEnumerable<T> where TQueue : struct, IUnboundedChannelQueue<T>
1617
{
1718
/// <summary>Task that indicates the channel has completed.</summary>
1819
private readonly TaskCompletionSource _completion;
1920
/// <summary>The items in the channel.</summary>
20-
private readonly ConcurrentQueue<T> _items = new ConcurrentQueue<T>();
21+
private readonly TQueue _items;
2122
/// <summary>Readers blocked reading from the channel.</summary>
2223
private readonly Deque<AsyncOperation<T>> _blockedReaders = new Deque<AsyncOperation<T>>();
2324
/// <summary>Whether to force continuations to be executed asynchronously from producer writes.</summary>
@@ -29,23 +30,24 @@ internal sealed class UnboundedChannel<T> : Channel<T>, IDebugEnumerable<T>
2930
private Exception? _doneWriting;
3031

3132
/// <summary>Initialize the channel.</summary>
32-
internal UnboundedChannel(bool runContinuationsAsynchronously)
33+
internal UnboundedChannel(TQueue items, bool runContinuationsAsynchronously)
3334
{
35+
_items = items;
3436
_runContinuationsAsynchronously = runContinuationsAsynchronously;
3537
_completion = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
3638
Reader = new UnboundedChannelReader(this);
3739
Writer = new UnboundedChannelWriter(this);
3840
}
3941

4042
[DebuggerDisplay("Items = {Count}")]
41-
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
43+
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
4244
private sealed class UnboundedChannelReader : ChannelReader<T>, IDebugEnumerable<T>
4345
{
44-
internal readonly UnboundedChannel<T> _parent;
46+
internal readonly UnboundedChannel<T, TQueue> _parent;
4547
private readonly AsyncOperation<T> _readerSingleton;
4648
private readonly AsyncOperation<bool> _waiterSingleton;
4749

48-
internal UnboundedChannelReader(UnboundedChannel<T> parent)
50+
internal UnboundedChannelReader(UnboundedChannel<T, TQueue> parent)
4951
{
5052
_parent = parent;
5153
_readerSingleton = new AsyncOperation<T>(parent._runContinuationsAsynchronously, pooled: true);
@@ -68,8 +70,8 @@ public override ValueTask<T> ReadAsync(CancellationToken cancellationToken)
6870
}
6971

7072
// Dequeue an item if we can.
71-
UnboundedChannel<T> parent = _parent;
72-
if (parent._items.TryDequeue(out T? item))
73+
UnboundedChannel<T, TQueue> parent = _parent;
74+
if (parent._items.IsThreadSafe && parent._items.TryDequeue(out T? item))
7375
{
7476
CompleteIfDone(parent);
7577
return new ValueTask<T>(item);
@@ -112,24 +114,60 @@ public override ValueTask<T> ReadAsync(CancellationToken cancellationToken)
112114

113115
public override bool TryRead([MaybeNullWhen(false)] out T item)
114116
{
115-
UnboundedChannel<T> parent = _parent;
117+
UnboundedChannel<T, TQueue> parent = _parent;
118+
return parent._items.IsThreadSafe ?
119+
LockFree(parent, out item) :
120+
Locked(parent, out item);
116121

117-
// Dequeue an item if we can
118-
if (parent._items.TryDequeue(out item))
122+
static bool LockFree(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
119123
{
120-
CompleteIfDone(parent);
121-
return true;
124+
if (parent._items.TryDequeue(out item))
125+
{
126+
CompleteIfDone(parent);
127+
return true;
128+
}
129+
130+
item = default;
131+
return false;
122132
}
123133

124-
item = default;
125-
return false;
134+
static bool Locked(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
135+
{
136+
lock (parent.SyncObj)
137+
{
138+
if (parent._items.TryDequeue(out item))
139+
{
140+
CompleteIfDone(parent);
141+
return true;
142+
}
143+
}
144+
145+
item = default;
146+
return false;
147+
}
126148
}
127149

128-
public override bool TryPeek([MaybeNullWhen(false)] out T item) =>
129-
_parent._items.TryPeek(out item);
150+
public override bool TryPeek([MaybeNullWhen(false)] out T item)
151+
{
152+
UnboundedChannel<T, TQueue> parent = _parent;
153+
return parent._items.IsThreadSafe ?
154+
parent._items.TryPeek(out item) :
155+
Locked(parent, out item);
156+
157+
// Separated out to keep the try/finally from preventing TryPeek from being inlined
158+
static bool Locked(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
159+
{
160+
lock (parent.SyncObj)
161+
{
162+
return parent._items.TryPeek(out item);
163+
}
164+
}
165+
}
130166

131-
private static void CompleteIfDone(UnboundedChannel<T> parent)
167+
private static void CompleteIfDone(UnboundedChannel<T, TQueue> parent)
132168
{
169+
Debug.Assert(parent._items.IsThreadSafe || Monitor.IsEntered(parent.SyncObj));
170+
133171
if (parent._doneWriting != null && parent._items.IsEmpty)
134172
{
135173
// If we've now emptied the items queue and we're not getting any more, complete.
@@ -144,12 +182,12 @@ public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationTo
144182
return new ValueTask<bool>(Task.FromCanceled<bool>(cancellationToken));
145183
}
146184

147-
if (!_parent._items.IsEmpty)
185+
if (_parent._items.IsThreadSafe && !_parent._items.IsEmpty)
148186
{
149187
return new ValueTask<bool>(true);
150188
}
151189

152-
UnboundedChannel<T> parent = _parent;
190+
UnboundedChannel<T, TQueue> parent = _parent;
153191

154192
lock (parent.SyncObj)
155193
{
@@ -192,15 +230,15 @@ public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationTo
192230
}
193231

194232
[DebuggerDisplay("Items = {ItemsCountForDebugger}")]
195-
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
233+
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
196234
private sealed class UnboundedChannelWriter : ChannelWriter<T>, IDebugEnumerable<T>
197235
{
198-
internal readonly UnboundedChannel<T> _parent;
199-
internal UnboundedChannelWriter(UnboundedChannel<T> parent) => _parent = parent;
236+
internal readonly UnboundedChannel<T, TQueue> _parent;
237+
internal UnboundedChannelWriter(UnboundedChannel<T, TQueue> parent) => _parent = parent;
200238

201239
public override bool TryComplete(Exception? error)
202240
{
203-
UnboundedChannel<T> parent = _parent;
241+
UnboundedChannel<T, TQueue> parent = _parent;
204242
bool completeTask;
205243

206244
lock (parent.SyncObj)
@@ -240,7 +278,7 @@ public override bool TryComplete(Exception? error)
240278

241279
public override bool TryWrite(T item)
242280
{
243-
UnboundedChannel<T> parent = _parent;
281+
UnboundedChannel<T, TQueue> parent = _parent;
244282
while (true)
245283
{
246284
AsyncOperation<T>? blockedReader = null;
@@ -321,7 +359,7 @@ public override ValueTask WriteAsync(T item, CancellationToken cancellationToken
321359
}
322360

323361
/// <summary>Gets the object used to synchronize access to all state on this instance.</summary>
324-
private object SyncObj => _items;
362+
private object SyncObj => _blockedReaders;
325363

326364
[Conditional("DEBUG")]
327365
private void AssertInvariants()

0 commit comments

Comments
 (0)