66using System . Collections . Generic ;
77using System . Linq ;
88using System . Threading ;
9+ using System . Threading . Channels ;
910using System . Threading . Tasks ;
10- using System . Threading . Tasks . Dataflow ;
1111using Microsoft . ML ;
1212using Microsoft . ML . CommandLine ;
1313using Microsoft . ML . Data ;
@@ -487,13 +487,12 @@ private static readonly FuncInstanceMethodInfo1<Cursor, int, Delegate> _createGe
487487 private int _liveCount ;
488488 private bool _doneConsuming ;
489489
490- private readonly BufferBlock < int > _toProduce ;
491- private readonly BufferBlock < int > _toConsume ;
490+ private readonly Channel < int > _toProduceChannel ;
491+ private readonly Channel < int > _toConsumeChannel ;
492492 private readonly Task _producerTask ;
493493 private Exception _producerTaskException ;
494494
495495 private readonly int [ ] _colToActivesIndex ;
496- private bool _disposed ;
497496
498497 public override DataViewSchema Schema => _input . Schema ;
499498
@@ -542,46 +541,20 @@ public Cursor(IChannelProvider provider, int poolRows, DataViewRowCursor input,
542541 _liveCount = 1 ;
543542
544543 // Set up the producer worker.
545- _toConsume = new BufferBlock < int > ( ) ;
546- _toProduce = new BufferBlock < int > ( ) ;
544+ _toConsumeChannel = Channel . CreateUnbounded < int > ( new UnboundedChannelOptions { SingleWriter = true } ) ;
545+ _toProduceChannel = Channel . CreateUnbounded < int > ( new UnboundedChannelOptions { SingleWriter = true } ) ;
547546 // First request the pool - 1 + block size rows, to get us going.
548- PostAssert ( _toProduce , _poolRows - 1 + _blockSize ) ;
547+ PostAssert ( _toProduceChannel , _poolRows - 1 + _blockSize ) ;
549548 // Queue up the remaining capacity.
550549 for ( int i = 1 ; i < _bufferDepth ; ++ i )
551- PostAssert ( _toProduce , _blockSize ) ;
550+ PostAssert ( _toProduceChannel , _blockSize ) ;
552551
553552 _producerTask = ProduceAsync ( ) ;
554553 }
555554
556- protected override void Dispose ( bool disposing )
555+ public static void PostAssert < T > ( Channel < T > target , T item )
557556 {
558- if ( _disposed )
559- return ;
560-
561- if ( disposing )
562- {
563- _toProduce . Complete ( ) ;
564- _producerTask . Wait ( ) ;
565-
566- // Complete the consumer after the producerTask has finished, since producerTask could
567- // have posted more items to _toConsume.
568- _toConsume . Complete ( ) ;
569-
570- // Drain both BufferBlocks - this prevents what appears to be memory leaks when using the VS Debugger
571- // because if a BufferBlock still contains items, its underlying Tasks are not getting completed.
572- // See https://github.com/dotnet/corefx/issues/30582 for the VS Debugger issue.
573- // See also https://github.com/dotnet/machinelearning/issues/4399
574- _toProduce . TryReceiveAll ( out _ ) ;
575- _toConsume . TryReceiveAll ( out _ ) ;
576- }
577-
578- _disposed = true ;
579- base . Dispose ( disposing ) ;
580- }
581-
582- public static void PostAssert < T > ( ITargetBlock < T > target , T item )
583- {
584- bool retval = target . Post ( item ) ;
557+ bool retval = target . Writer . TryWrite ( item ) ;
585558 Contracts . Assert ( retval ) ;
586559 }
587560
@@ -595,12 +568,13 @@ private async Task ProduceAsync()
595568 try
596569 {
597570 int circularIndex = 0 ;
598- while ( await _toProduce . OutputAvailableAsync ( ) . ConfigureAwait ( false ) )
571+ while ( await _toProduceChannel . Reader . WaitToReadAsync ( ) . ConfigureAwait ( false ) )
599572 {
600573 int requested ;
601- if ( ! _toProduce . TryReceive ( out requested ) )
574+ if ( ! _toProduceChannel . Reader . TryRead ( out requested ) )
602575 {
603- // OutputAvailableAsync returned true, but TryReceive returned false -
576+ // The producer Channel's Reader.WaitToReadAsync returned true,
577+ // but the Reader's TryRead returned false -
604578 // so loop back around and try again.
605579 continue ;
606580 }
@@ -619,14 +593,14 @@ private async Task ProduceAsync()
619593 if ( circularIndex == _pipeIndices . Length )
620594 circularIndex = 0 ;
621595 }
622- PostAssert ( _toConsume , numRows ) ;
596+ PostAssert ( _toConsumeChannel , numRows ) ;
623597 if ( numRows < requested )
624598 {
625599 // We've reached the end of the cursor. Send the sentinel, then exit.
626600 // This assumes that the receiver will receive things in Post order
627601 // (so that the sentinel is received, after the last Post).
628602 if ( numRows > 0 )
629- PostAssert ( _toConsume , 0 ) ;
603+ PostAssert ( _toConsumeChannel , 0 ) ;
630604 return ;
631605 }
632606 }
@@ -635,7 +609,7 @@ private async Task ProduceAsync()
635609 {
636610 _producerTaskException = ex ;
637611 // Send the sentinel in this case as well, the field will be checked.
638- PostAssert ( _toConsume , 0 ) ;
612+ PostAssert ( _toConsumeChannel , 0 ) ;
639613 }
640614 }
641615
@@ -652,26 +626,32 @@ protected override bool MoveNextCore()
652626 {
653627 // We should let the producer know it can give us more stuff.
654628 // It is possible for int values to be sent beyond the
655- // end of the sentinel , but we suppose this is irrelevant.
656- PostAssert ( _toProduce , _deadCount ) ;
629+ // end of the Channel , but we suppose this is irrelevant.
630+ PostAssert ( _toProduceChannel , _deadCount ) ;
657631 _deadCount = 0 ;
658632 }
659633
660634 while ( _liveCount < _poolRows && ! _doneConsuming )
661635 {
662636 // We are under capacity. Try to get some more.
663- int got = _toConsume . Receive ( ) ;
664- if ( got == 0 )
637+ while ( _toConsumeChannel . Reader . WaitToReadAsync ( ) . GetAwaiter ( ) . GetResult ( ) )
665638 {
666- // We've reached the end sentinel. There's no reason
667- // to attempt further communication with the producer.
668- // Check whether something horrible happened.
669- if ( _producerTaskException != null )
670- throw Ch . Except ( _producerTaskException , "Shuffle input cursor reader failed with an exception" ) ;
671- _doneConsuming = true ;
672- break ;
639+ var hasReadItem = _toConsumeChannel . Reader . TryRead ( out int got ) ;
640+ if ( hasReadItem )
641+ {
642+ if ( got == 0 )
643+ {
644+ // We've reached the end of the Channel. There's no reason
645+ // to attempt further communication with the producer.
646+ // Check whether something horrible happened.
647+ if ( _producerTaskException != null )
648+ throw Ch . Except ( _producerTaskException , "Shuffle input cursor reader failed with an exception" ) ;
649+ _doneConsuming = true ;
650+ break ;
651+ }
652+ _liveCount += got ;
653+ }
673654 }
674- _liveCount += got ;
675655 }
676656 if ( _liveCount == 0 )
677657 return false ;
0 commit comments