4
4
5
5
using System . Buffers ;
6
6
using System . Diagnostics ;
7
+ using System . Runtime . CompilerServices ;
7
8
using System . Threading ;
8
9
using System . Threading . Tasks ;
10
+ using System . Threading . Tasks . Sources ;
9
11
10
12
namespace System . IO . Pipelines
11
13
{
12
- internal class StreamPipeReader : PipeReader
14
+ internal class StreamPipeReader : PipeReader , IValueTaskSource < ReadResult >
13
15
{
14
16
internal const int InitialSegmentPoolSize = 4 ; // 16K
15
17
internal const int MaxSegmentPoolSize = 256 ; // 1MB
@@ -28,11 +30,19 @@ internal class StreamPipeReader : PipeReader
28
30
private BufferSegment _readTail ;
29
31
private long _bufferedBytes ;
30
32
private bool _examinedEverything ;
31
- private object _lock = new object ( ) ;
33
+ private readonly object _lock = new object ( ) ;
32
34
33
35
// Mutable struct! Don't make this readonly
34
36
private BufferSegmentStack _bufferSegmentPool ;
35
- private bool _leaveOpen ;
37
+ private readonly bool _leaveOpen ;
38
+
39
+ // State for async reads
40
+ private volatile bool _readInProgress ;
41
+ private readonly Action _onReadCompleted ;
42
+ private ManualResetValueTaskSourceCore < ReadResult > _readMrvts ;
43
+ private ValueTaskAwaiter < int > _readAwaiter ;
44
+ private CancellationToken _readCancellation ;
45
+ private CancellationTokenRegistration _readRegistration ;
36
46
37
47
/// <summary>
38
48
/// Creates a new StreamPipeReader.
@@ -53,6 +63,7 @@ public StreamPipeReader(Stream readingStream, StreamPipeReaderOptions options)
53
63
_pool = options . Pool == MemoryPool < byte > . Shared ? null : options . Pool ;
54
64
_bufferSize = _pool == null ? options . BufferSize : Math . Min ( options . BufferSize , _pool . MaxBufferSize ) ;
55
65
_leaveOpen = options . LeaveOpen ;
66
+ _onReadCompleted = new Action ( OnReadCompleted ) ;
56
67
}
57
68
58
69
/// <summary>
@@ -72,11 +83,7 @@ private CancellationTokenSource InternalTokenSource
72
83
{
73
84
lock ( _lock )
74
85
{
75
- if ( _internalTokenSource == null )
76
- {
77
- _internalTokenSource = new CancellationTokenSource ( ) ;
78
- }
79
- return _internalTokenSource ;
86
+ return ( _internalTokenSource ??= new CancellationTokenSource ( ) ) ;
80
87
}
81
88
}
82
89
}
@@ -193,39 +200,59 @@ public override void OnWriterCompleted(Action<Exception, object> callback, objec
193
200
}
194
201
195
202
/// <inheritdoc />
196
- public override async ValueTask < ReadResult > ReadAsync ( CancellationToken cancellationToken = default )
203
+ public override ValueTask < ReadResult > ReadAsync ( CancellationToken cancellationToken = default )
197
204
{
198
- // TODO ReadyAsync needs to throw if there are overlapping reads.
199
- ThrowIfCompleted ( ) ;
200
-
201
- // PERF: store InternalTokenSource locally to avoid querying it twice (which acquires a lock)
202
- CancellationTokenSource tokenSource = InternalTokenSource ;
203
- if ( TryReadInternal ( tokenSource , out ReadResult readResult ) )
205
+ if ( _readInProgress )
204
206
{
205
- return readResult ;
207
+ // Throw if there are overlapping reads; throwing unwrapped as it suggests last read was not awaited
208
+ // so we surface it directly rather than wrapped in a Task (as this one will likely also not be awaited).
209
+ ThrowConcurrentReadsNotSupported ( ) ;
206
210
}
211
+ _readInProgress = true ;
207
212
208
- if ( _isStreamCompleted )
213
+ bool isAsync = false ;
214
+ try
209
215
{
210
- return new ReadResult ( buffer : default , isCanceled : false , isCompleted : true ) ;
211
- }
212
216
213
- var reg = new CancellationTokenRegistration ( ) ;
214
- if ( cancellationToken . CanBeCanceled )
215
- {
216
- reg = cancellationToken . UnsafeRegister ( state => ( ( StreamPipeReader ) state ) . Cancel ( ) , this ) ;
217
- }
217
+ ThrowIfCompleted ( ) ;
218
+
219
+ // PERF: store InternalTokenSource locally to avoid querying it twice (which acquires a lock)
220
+ CancellationTokenSource tokenSource = InternalTokenSource ;
221
+ if ( TryReadInternal ( tokenSource , out ReadResult readResult ) )
222
+ {
223
+ return new ValueTask < ReadResult > ( readResult ) ;
224
+ }
225
+
226
+ if ( _isStreamCompleted )
227
+ {
228
+ return new ValueTask < ReadResult > ( new ReadResult ( buffer : default , isCanceled : false , isCompleted : true ) ) ;
229
+ }
230
+
231
+ var reg = new CancellationTokenRegistration ( ) ;
232
+ if ( cancellationToken . CanBeCanceled )
233
+ {
234
+ reg = cancellationToken . UnsafeRegister ( state => ( ( StreamPipeReader ) state ) . Cancel ( ) , this ) ;
235
+ }
218
236
219
- using ( reg )
220
- {
221
237
var isCanceled = false ;
222
238
try
223
239
{
224
240
AllocateReadTail ( ) ;
225
241
226
242
Memory < byte > buffer = _readTail . AvailableMemory . Slice ( _readTail . End ) ;
227
243
228
- int length = await InnerStream . ReadAsync ( buffer , tokenSource . Token ) . ConfigureAwait ( false ) ;
244
+ ValueTask < int > resultTask = InnerStream . ReadAsync ( buffer , tokenSource . Token ) ;
245
+ int length ;
246
+ if ( resultTask . IsCompletedSuccessfully )
247
+ {
248
+ length = resultTask . Result ;
249
+ }
250
+ else
251
+ {
252
+ isAsync = true ;
253
+ // Need to go async
254
+ return CompleteReadAsync ( resultTask , cancellationToken , reg ) ;
255
+ }
229
256
230
257
Debug . Assert ( length + _readTail . End <= _readTail . AvailableMemory . Length ) ;
231
258
@@ -252,8 +279,27 @@ public override async ValueTask<ReadResult> ReadAsync(CancellationToken cancella
252
279
}
253
280
254
281
}
282
+ finally
283
+ {
284
+ if ( ! isAsync )
285
+ {
286
+ reg . Dispose ( ) ;
287
+ }
288
+ }
255
289
256
- return new ReadResult ( GetCurrentReadOnlySequence ( ) , isCanceled , _isStreamCompleted ) ;
290
+ return new ValueTask < ReadResult > ( new ReadResult ( GetCurrentReadOnlySequence ( ) , isCanceled , _isStreamCompleted ) ) ;
291
+ }
292
+ catch ( Exception ex )
293
+ {
294
+ return new ValueTask < ReadResult > ( Task . FromException < ReadResult > ( ex ) ) ;
295
+ }
296
+ finally
297
+ {
298
+ if ( ! isAsync )
299
+ {
300
+ Debug . Assert ( _readInProgress ) ;
301
+ _readInProgress = false ;
302
+ }
257
303
}
258
304
}
259
305
@@ -275,6 +321,11 @@ private void ThrowIfCompleted()
275
321
276
322
public override bool TryRead ( out ReadResult result )
277
323
{
324
+ if ( _readInProgress )
325
+ {
326
+ ThrowConcurrentReadsNotSupported ( ) ;
327
+ }
328
+
278
329
ThrowIfCompleted ( ) ;
279
330
280
331
return TryReadInternal ( InternalTokenSource , out result ) ;
@@ -362,5 +413,113 @@ private void Cancel()
362
413
{
363
414
InternalTokenSource . Cancel ( ) ;
364
415
}
416
+
417
+ static void ThrowConcurrentReadsNotSupported ( )
418
+ {
419
+ throw new InvalidOperationException ( $ "Concurrent reads are not supported; await the { nameof ( ValueTask < ReadResult > ) } before starting next read.") ;
420
+ }
421
+
422
+ private ValueTask < ReadResult > CompleteReadAsync ( ValueTask < int > task , CancellationToken cancellationToken , CancellationTokenRegistration reg )
423
+ {
424
+ Debug . Assert ( _readInProgress , "Read not in progress" ) ;
425
+
426
+ _readCancellation = cancellationToken ;
427
+ _readRegistration = reg ;
428
+
429
+ _readAwaiter = task . GetAwaiter ( ) ;
430
+
431
+ return new ValueTask < ReadResult > ( this , _readMrvts . Version ) ;
432
+ }
433
+
434
+ private void OnReadCompleted ( )
435
+ {
436
+ try
437
+ {
438
+ int length = _readAwaiter . GetResult ( ) ;
439
+
440
+ Debug . Assert ( length + _readTail . End <= _readTail . AvailableMemory . Length ) ;
441
+
442
+ _readTail . End += length ;
443
+ _bufferedBytes += length ;
444
+
445
+ if ( length == 0 )
446
+ {
447
+ _isStreamCompleted = true ;
448
+ }
449
+
450
+ _readMrvts . SetResult ( new ReadResult ( GetCurrentReadOnlySequence ( ) , isCanceled : false , _isStreamCompleted ) ) ;
451
+ }
452
+ catch ( OperationCanceledException oce )
453
+ {
454
+ // Get the source before clearing (and replacing)
455
+ CancellationTokenSource tokenSource = InternalTokenSource ;
456
+ ClearCancellationToken ( ) ;
457
+ if ( tokenSource . IsCancellationRequested && ! _readCancellation . IsCancellationRequested )
458
+ {
459
+ // Catch cancellation and translate it into setting isCanceled = true
460
+ _readMrvts . SetResult ( new ReadResult ( GetCurrentReadOnlySequence ( ) , isCanceled : true , _isStreamCompleted ) ) ;
461
+ }
462
+ else
463
+ {
464
+ _readMrvts . SetException ( oce ) ;
465
+ }
466
+ }
467
+ catch ( Exception ex )
468
+ {
469
+ _readMrvts . SetException ( ex ) ;
470
+ }
471
+ finally
472
+ {
473
+ _readRegistration . Dispose ( ) ;
474
+ _readRegistration = default ;
475
+ }
476
+ }
477
+
478
+ ReadResult IValueTaskSource < ReadResult > . GetResult ( short token )
479
+ {
480
+ ValidateReading ( ) ;
481
+ ReadResult result = _readMrvts . GetResult ( token ) ;
482
+
483
+ _readCancellation = default ;
484
+ _readAwaiter = default ;
485
+ _readMrvts . Reset ( ) ;
486
+
487
+ Debug . Assert ( _readInProgress ) ;
488
+ _readInProgress = false ;
489
+
490
+ return result ;
491
+ }
492
+
493
+ ValueTaskSourceStatus IValueTaskSource < ReadResult > . GetStatus ( short token )
494
+ => _readMrvts . GetStatus ( token ) ;
495
+
496
+ void IValueTaskSource < ReadResult > . OnCompleted ( Action < object > continuation , object state , short token , ValueTaskSourceOnCompletedFlags flags )
497
+ {
498
+ ValidateReading ( ) ;
499
+ _readMrvts . OnCompleted ( continuation , state , token , flags ) ;
500
+
501
+ if ( ( flags & ValueTaskSourceOnCompletedFlags . FlowExecutionContext ) != 0 )
502
+ {
503
+ _readAwaiter . OnCompleted ( _onReadCompleted ) ;
504
+ }
505
+ else
506
+ {
507
+ _readAwaiter . UnsafeOnCompleted ( _onReadCompleted ) ;
508
+ }
509
+ }
510
+
511
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
512
+ private void ValidateReading ( )
513
+ {
514
+ if ( ! _readInProgress )
515
+ {
516
+ ThrowReadNotInProgress ( ) ;
517
+ }
518
+
519
+ static void ThrowReadNotInProgress ( )
520
+ {
521
+ throw new InvalidOperationException ( "Read not in progress" ) ;
522
+ }
523
+ }
365
524
}
366
525
}
0 commit comments