@@ -32,12 +32,15 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
3232 guard active else {
3333 return nil
3434 }
35+
3536 let generation = channel. establish ( )
37+ let nextTokenStatus = ManagedCriticalState < ChannelTokenStatus > ( . new)
38+
3639 do {
37- let value : Element ? = try await withTaskCancellationHandler { [ channel] in
38- channel. cancel ( generation)
40+ let value = try await withTaskCancellationHandler { [ channel] in
41+ channel. cancelNext ( nextTokenStatus , generation)
3942 } operation: {
40- try await channel. next ( generation)
43+ try await channel. next ( nextTokenStatus , generation)
4144 }
4245
4346 if let value = value {
@@ -52,72 +55,49 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
5255 }
5356 }
5457 }
55-
56- struct Awaiting : Hashable {
58+
59+ typealias Pending = ChannelToken < UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > >
60+ typealias Awaiting = ChannelToken < UnsafeContinuation < Element ? , Error > >
61+
62+ struct ChannelToken < Continuation> : Hashable {
5763 var generation : Int
58- var continuation : UnsafeContinuation < Element ? , Error > ?
59- let cancelled : Bool
60-
61- init ( generation: Int , continuation: UnsafeContinuation < Element ? , Error > ) {
64+ var continuation : Continuation ?
65+
66+ init ( generation: Int , continuation: Continuation ) {
6267 self . generation = generation
6368 self . continuation = continuation
64- cancelled = false
6569 }
66-
70+
6771 init ( placeholder generation: Int ) {
6872 self . generation = generation
6973 self . continuation = nil
70- cancelled = false
7174 }
72-
73- init ( cancelled generation: Int ) {
74- self . generation = generation
75- self . continuation = nil
76- cancelled = true
77- }
78-
75+
7976 func hash( into hasher: inout Hasher ) {
8077 hasher. combine ( generation)
8178 }
82-
83- static func == ( _ lhs: Awaiting , _ rhs: Awaiting ) -> Bool {
79+
80+ static func == ( _ lhs: ChannelToken , _ rhs: ChannelToken ) -> Bool {
8481 return lhs. generation == rhs. generation
8582 }
8683 }
8784
85+
86+ enum ChannelTokenStatus : Equatable {
87+ case new
88+ case cancelled
89+ }
90+
8891 enum Termination {
8992 case finished
9093 case failed( Error )
9194 }
9295
9396 enum Emission {
9497 case idle
95- case pending( [ UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > ] )
98+ case pending( Set < Pending > )
9699 case awaiting( Set < Awaiting > )
97100 case terminated( Termination )
98-
99- var isTerminated : Bool {
100- guard case . terminated = self else { return false }
101- return true
102- }
103-
104- mutating func cancel( _ generation: Int ) -> UnsafeContinuation < Element ? , Error > ? {
105- switch self {
106- case . awaiting( var awaiting) :
107- let continuation = awaiting. remove ( Awaiting ( placeholder: generation) ) ? . continuation
108- if awaiting. isEmpty {
109- self = . idle
110- } else {
111- self = . awaiting( awaiting)
112- }
113- return continuation
114- case . idle:
115- self = . awaiting( [ Awaiting ( cancelled: generation) ] )
116- return nil
117- default :
118- return nil
119- }
120- }
121101 }
122102
123103 struct State {
@@ -135,19 +115,45 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
135115 return state. generation
136116 }
137117 }
138-
139- func cancel( _ generation: Int ) {
140- state. withCriticalRegion { state in
141- state. emission. cancel ( generation)
118+
119+ func cancelNext( _ nextTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) {
120+ state. withCriticalRegion { state -> UnsafeContinuation < Element ? , Error > ? in
121+ let continuation : UnsafeContinuation < Element ? , Error > ?
122+
123+ switch state. emission {
124+ case . awaiting( var nexts) :
125+ continuation = nexts. remove ( Awaiting ( placeholder: generation) ) ? . continuation
126+ if nexts. isEmpty {
127+ state. emission = . idle
128+ } else {
129+ state. emission = . awaiting( nexts)
130+ }
131+ default :
132+ continuation = nil
133+ }
134+
135+ nextTokenStatus. withCriticalRegion { status in
136+ if status == . new {
137+ status = . cancelled
138+ }
139+ }
140+
141+ return continuation
142142 } ? . resume ( returning: nil )
143143 }
144144
145- func next( _ generation: Int ) async throws -> Element ? {
146- return try await withUnsafeThrowingContinuation { continuation in
145+ func next( _ nextTokenStatus : ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) async throws -> Element ? {
146+ return try await withUnsafeThrowingContinuation { ( continuation: UnsafeContinuation < Element ? , Error > ) in
147147 var cancelled = false
148148 var potentialTermination : Termination ?
149149
150150 state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Error > ? , Never > ? in
151+
152+ if nextTokenStatus. withCriticalRegion ( { $0 } ) == . cancelled {
153+ cancelled = true
154+ return nil
155+ }
156+
151157 switch state. emission {
152158 case . idle:
153159 state. emission = . awaiting( [ Awaiting ( generation: generation, continuation: continuation) ] )
@@ -159,17 +165,10 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
159165 } else {
160166 state. emission = . pending( sends)
161167 }
162- return UnsafeResumption ( continuation: send, success: continuation)
168+ return UnsafeResumption ( continuation: send. continuation , success: continuation)
163169 case . awaiting( var nexts) :
164- if nexts. update ( with: Awaiting ( generation: generation, continuation: continuation) ) != nil {
165- nexts. remove ( Awaiting ( placeholder: generation) )
166- cancelled = true
167- }
168- if nexts. isEmpty {
169- state. emission = . idle
170- } else {
171- state. emission = . awaiting( nexts)
172- }
170+ nexts. update ( with: Awaiting ( generation: generation, continuation: continuation) )
171+ state. emission = . awaiting( nexts)
173172 return nil
174173 case . terminated( let termination) :
175174 potentialTermination = termination
@@ -196,8 +195,67 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
196195 }
197196 }
198197
198+ func cancelSend( _ sendTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) {
199+ state. withCriticalRegion { state -> UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > ? in
200+ let continuation : UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > ?
201+
202+ switch state. emission {
203+ case . pending( var sends) :
204+ let send = sends. remove ( Pending ( placeholder: generation) )
205+ if sends. isEmpty {
206+ state. emission = . idle
207+ } else {
208+ state. emission = . pending( sends)
209+ }
210+ continuation = send? . continuation
211+ default :
212+ continuation = nil
213+ }
214+
215+ sendTokenStatus. withCriticalRegion { status in
216+ if status == . new {
217+ status = . cancelled
218+ }
219+ }
220+
221+ return continuation
222+ } ? . resume ( returning: nil )
223+ }
224+
225+ func send( _ sendTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int , _ element: Element ) async {
226+ let continuation : UnsafeContinuation < Element ? , Error > ? = await withUnsafeContinuation { continuation in
227+ state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Error > ? , Never > ? in
228+
229+ if sendTokenStatus. withCriticalRegion ( { $0 } ) == . cancelled {
230+ return UnsafeResumption ( continuation: continuation, success: nil )
231+ }
232+
233+ switch state. emission {
234+ case . idle:
235+ state. emission = . pending( [ Pending ( generation: generation, continuation: continuation) ] )
236+ return nil
237+ case . pending( var sends) :
238+ sends. update ( with: Pending ( generation: generation, continuation: continuation) )
239+ state. emission = . pending( sends)
240+ return nil
241+ case . awaiting( var nexts) :
242+ let next = nexts. removeFirst ( ) . continuation
243+ if nexts. count == 0 {
244+ state. emission = . idle
245+ } else {
246+ state. emission = . awaiting( nexts)
247+ }
248+ return UnsafeResumption ( continuation: continuation, success: next)
249+ case . terminated:
250+ return UnsafeResumption ( continuation: continuation, success: nil )
251+ }
252+ } ? . resume ( )
253+ }
254+ continuation? . resume ( returning: element)
255+ }
256+
199257 func terminateAll( error: Failure ? = nil ) {
200- let ( sends, nexts) = state. withCriticalRegion { state -> ( [ UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > ] , Set < Awaiting > ) in
258+ let ( sends, nexts) = state. withCriticalRegion { state -> ( Set < Pending > , Set < Awaiting > ) in
201259
202260 let nextState : Emission
203261 if let error = error {
@@ -222,7 +280,7 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
222280 }
223281
224282 for send in sends {
225- send. resume ( returning: nil )
283+ send. continuation ? . resume ( returning: nil )
226284 }
227285
228286 if let error = error {
@@ -234,45 +292,20 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
234292 next. continuation? . resume ( returning: nil )
235293 }
236294 }
237-
238- }
239-
240- func _send( _ element: Element ) async {
241- await withTaskCancellationHandler {
242- terminateAll ( )
243- } operation: {
244- let continuation : UnsafeContinuation < Element ? , Error > ? = await withUnsafeContinuation { continuation in
245- state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Error > ? , Never > ? in
246- switch state. emission {
247- case . idle:
248- state. emission = . pending( [ continuation] )
249- return nil
250- case . pending( var sends) :
251- sends. append ( continuation)
252- state. emission = . pending( sends)
253- return nil
254- case . awaiting( var nexts) :
255- let next = nexts. removeFirst ( ) . continuation
256- if nexts. count == 0 {
257- state. emission = . idle
258- } else {
259- state. emission = . awaiting( nexts)
260- }
261- return UnsafeResumption ( continuation: continuation, success: next)
262- case . terminated:
263- return UnsafeResumption ( continuation: continuation, success: nil )
264- }
265- } ? . resume ( )
266- }
267- continuation? . resume ( returning: element)
268- }
269295 }
270296
271297 /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
272298 /// or when a call to `finish()`/`fail(_:)` is made from another Task.
273299 /// If the channel is already finished then this returns immediately
274300 public func send( _ element: Element ) async {
275- await _send ( element)
301+ let generation = establish ( )
302+ let sendTokenStatus = ManagedCriticalState < ChannelTokenStatus > ( . new)
303+
304+ await withTaskCancellationHandler { [ weak self] in
305+ self ? . cancelSend ( sendTokenStatus, generation)
306+ } operation: {
307+ await send ( sendTokenStatus, generation, element)
308+ }
276309 }
277310
278311 /// Send an error to all awaiting iterations.
0 commit comments