Skip to content

Commit 3cc8ce8

Browse files
authored
Use Channel Instead of BufferBlock (#5123)
* Update file to use channel
* Add channels package
* Update for feedback
* Update more buffer blocks to use channel
* Add version to props file
* Remove build dependencies that aren't needed
* Update comments
* Add back data flow package
* Remove data flow package everywhere
* Update from PR feedback
* Revert carriage return
* Updates from comments
* Remove disposed variable
* Block receiving thread
1 parent a00a222 commit 3cc8ce8

File tree

6 files changed

+44
-64
lines changed

6 files changed

+44
-64
lines changed

build/Dependencies.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
<SystemCollectionsImmutableVersion>1.5.0</SystemCollectionsImmutableVersion>
99
<SystemMemoryVersion>4.5.1</SystemMemoryVersion>
1010
<SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
11-
<SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion>
11+
<SystemThreadingChannelsPackageVersion>4.7.1</SystemThreadingChannelsPackageVersion>
1212
</PropertyGroup>
1313

1414
<!-- Other/Non-Core Product Dependencies -->

pkg/Microsoft.ML/Microsoft.ML.nupkgproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
1313
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
14-
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
1514
<PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" />
1615
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
1716
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
17+
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
1818
</ItemGroup>
1919

2020
<ItemGroup>

src/Microsoft.ML.Data/Microsoft.ML.Data.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
<ItemGroup>
1111
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
1212
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
13-
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
13+
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
1414
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
1515
</ItemGroup>
1616

src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs

Lines changed: 34 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
using System.Collections.Generic;
77
using System.Linq;
88
using System.Threading;
9+
using System.Threading.Channels;
910
using System.Threading.Tasks;
10-
using System.Threading.Tasks.Dataflow;
1111
using Microsoft.ML;
1212
using Microsoft.ML.CommandLine;
1313
using Microsoft.ML.Data;
@@ -487,13 +487,12 @@ private static readonly FuncInstanceMethodInfo1<Cursor, int, Delegate> _createGe
487487
private int _liveCount;
488488
private bool _doneConsuming;
489489

490-
private readonly BufferBlock<int> _toProduce;
491-
private readonly BufferBlock<int> _toConsume;
490+
private readonly Channel<int> _toProduceChannel;
491+
private readonly Channel<int> _toConsumeChannel;
492492
private readonly Task _producerTask;
493493
private Exception _producerTaskException;
494494

495495
private readonly int[] _colToActivesIndex;
496-
private bool _disposed;
497496

498497
public override DataViewSchema Schema => _input.Schema;
499498

@@ -542,46 +541,20 @@ public Cursor(IChannelProvider provider, int poolRows, DataViewRowCursor input,
542541
_liveCount = 1;
543542

544543
// Set up the producer worker.
545-
_toConsume = new BufferBlock<int>();
546-
_toProduce = new BufferBlock<int>();
544+
_toConsumeChannel = Channel.CreateUnbounded<int>(new UnboundedChannelOptions { SingleWriter = true });
545+
_toProduceChannel = Channel.CreateUnbounded<int>(new UnboundedChannelOptions { SingleWriter = true });
547546
// First request the pool - 1 + block size rows, to get us going.
548-
PostAssert(_toProduce, _poolRows - 1 + _blockSize);
547+
PostAssert(_toProduceChannel, _poolRows - 1 + _blockSize);
549548
// Queue up the remaining capacity.
550549
for (int i = 1; i < _bufferDepth; ++i)
551-
PostAssert(_toProduce, _blockSize);
550+
PostAssert(_toProduceChannel, _blockSize);
552551

553552
_producerTask = ProduceAsync();
554553
}
555554

556-
protected override void Dispose(bool disposing)
555+
public static void PostAssert<T>(Channel<T> target, T item)
557556
{
558-
if (_disposed)
559-
return;
560-
561-
if (disposing)
562-
{
563-
_toProduce.Complete();
564-
_producerTask.Wait();
565-
566-
// Complete the consumer after the producerTask has finished, since producerTask could
567-
// have posted more items to _toConsume.
568-
_toConsume.Complete();
569-
570-
// Drain both BufferBlocks - this prevents what appears to be memory leaks when using the VS Debugger
571-
// because if a BufferBlock still contains items, its underlying Tasks are not getting completed.
572-
// See https://github.com/dotnet/corefx/issues/30582 for the VS Debugger issue.
573-
// See also https://github.com/dotnet/machinelearning/issues/4399
574-
_toProduce.TryReceiveAll(out _);
575-
_toConsume.TryReceiveAll(out _);
576-
}
577-
578-
_disposed = true;
579-
base.Dispose(disposing);
580-
}
581-
582-
public static void PostAssert<T>(ITargetBlock<T> target, T item)
583-
{
584-
bool retval = target.Post(item);
557+
bool retval = target.Writer.TryWrite(item);
585558
Contracts.Assert(retval);
586559
}
587560

@@ -595,12 +568,13 @@ private async Task ProduceAsync()
595568
try
596569
{
597570
int circularIndex = 0;
598-
while (await _toProduce.OutputAvailableAsync().ConfigureAwait(false))
571+
while (await _toProduceChannel.Reader.WaitToReadAsync().ConfigureAwait(false))
599572
{
600573
int requested;
601-
if (!_toProduce.TryReceive(out requested))
574+
if (!_toProduceChannel.Reader.TryRead(out requested))
602575
{
603-
// OutputAvailableAsync returned true, but TryReceive returned false -
576+
// The producer Channel's Reader.WaitToReadAsync returned true,
577+
// but the Reader's TryRead returned false -
604578
// so loop back around and try again.
605579
continue;
606580
}
@@ -619,14 +593,14 @@ private async Task ProduceAsync()
619593
if (circularIndex == _pipeIndices.Length)
620594
circularIndex = 0;
621595
}
622-
PostAssert(_toConsume, numRows);
596+
PostAssert(_toConsumeChannel, numRows);
623597
if (numRows < requested)
624598
{
625599
// We've reached the end of the cursor. Send the sentinel, then exit.
626600
// This assumes that the receiver will receive things in Post order
627601
// (so that the sentinel is received after the last Post).
628602
if (numRows > 0)
629-
PostAssert(_toConsume, 0);
603+
PostAssert(_toConsumeChannel, 0);
630604
return;
631605
}
632606
}
@@ -635,7 +609,7 @@ private async Task ProduceAsync()
635609
{
636610
_producerTaskException = ex;
637611
// Send the sentinel in this case as well, the field will be checked.
638-
PostAssert(_toConsume, 0);
612+
PostAssert(_toConsumeChannel, 0);
639613
}
640614
}
641615

@@ -652,26 +626,32 @@ protected override bool MoveNextCore()
652626
{
653627
// We should let the producer know it can give us more stuff.
654628
// It is possible for int values to be sent beyond the
655-
// end of the sentinel, but we suppose this is irrelevant.
656-
PostAssert(_toProduce, _deadCount);
629+
// end of the Channel, but we suppose this is irrelevant.
630+
PostAssert(_toProduceChannel, _deadCount);
657631
_deadCount = 0;
658632
}
659633

660634
while (_liveCount < _poolRows && !_doneConsuming)
661635
{
662636
// We are under capacity. Try to get some more.
663-
int got = _toConsume.Receive();
664-
if (got == 0)
637+
while (_toConsumeChannel.Reader.WaitToReadAsync().GetAwaiter().GetResult())
665638
{
666-
// We've reached the end sentinel. There's no reason
667-
// to attempt further communication with the producer.
668-
// Check whether something horrible happened.
669-
if (_producerTaskException != null)
670-
throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
671-
_doneConsuming = true;
672-
break;
639+
var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
640+
if (hasReadItem)
641+
{
642+
if (got == 0)
643+
{
644+
// We've reached the end of the Channel. There's no reason
645+
// to attempt further communication with the producer.
646+
// Check whether something horrible happened.
647+
if (_producerTaskException != null)
648+
throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
649+
_doneConsuming = true;
650+
break;
651+
}
652+
_liveCount += got;
653+
}
673654
}
674-
_liveCount += got;
675655
}
676656
if (_liveCount == 0)
677657
return false;

src/Microsoft.ML.Sweeper/AsyncSweeper.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Threading;
8+
using System.Threading.Channels;
89
using System.Threading.Tasks;
9-
using System.Threading.Tasks.Dataflow;
1010
using Microsoft.ML;
1111
using Microsoft.ML.CommandLine;
1212
using Microsoft.ML.Internal.Utilities;
@@ -168,7 +168,7 @@ public sealed class Options
168168
private readonly object _lock;
169169
private readonly CancellationTokenSource _cts;
170170

171-
private readonly BufferBlock<ParameterSetWithId> _paramQueue;
171+
private readonly Channel<ParameterSetWithId> _paramChannel;
172172
private readonly int _relaxation;
173173
private readonly ISweeper _baseSweeper;
174174
private readonly IHost _host;
@@ -208,7 +208,8 @@ public DeterministicSweeperAsync(IHostEnvironment env, Options options)
208208
_lock = new object();
209209
_results = new List<IRunResult>();
210210
_nullRuns = new HashSet<int>();
211-
_paramQueue = new BufferBlock<ParameterSetWithId>();
211+
_paramChannel = Channel.CreateUnbounded<ParameterSetWithId>(
212+
new UnboundedChannelOptions { SingleWriter = true });
212213

213214
PrepareNextBatch(null);
214215
}
@@ -220,12 +221,12 @@ private void PrepareNextBatch(IEnumerable<IRunResult> results)
220221
if (Utils.Size(paramSets) == 0)
221222
{
222223
// Mark the queue as completed.
223-
_paramQueue.Complete();
224+
_paramChannel.Writer.Complete();
224225
return;
225226
}
226227
// Assign an id to each ParameterSet and enqueue it.
227228
foreach (var paramSet in paramSets)
228-
_paramQueue.Post(new ParameterSetWithId(_numGenerated++, paramSet));
229+
_paramChannel.Writer.TryWrite(new ParameterSetWithId(_numGenerated++, paramSet));
229230
EnsureResultsSize();
230231
}
231232

@@ -278,7 +279,7 @@ public async Task<ParameterSetWithId> ProposeAsync()
278279
return null;
279280
try
280281
{
281-
return await _paramQueue.ReceiveAsync(_cts.Token);
282+
return await _paramChannel.Reader.ReadAsync(_cts.Token);
282283
}
283284
catch (InvalidOperationException)
284285
{

test/Microsoft.ML.FSharp.Tests/SmokeTests.fs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Google.Protobuf.dll"
2323
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Newtonsoft.Json.dll"
2424
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.CodeDom.dll"
25-
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.Threading.Tasks.Dataflow.dll"
2625
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.CpuMath.dll"
2726
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Data.dll"
2827
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Transforms.dll"

0 commit comments

Comments
 (0)