Use Channel Instead of BufferBlock #5123

Merged · 14 commits · Jul 8, 2020
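For orientation before the diff: this PR replaces System.Threading.Tasks.Dataflow's BufferBlock&lt;T&gt; with System.Threading.Channels' Channel&lt;T&gt;. The following is a minimal, self-contained sketch of how the two APIs map onto each other; the class and variable names are illustrative only and do not appear in the PR.

using System;
using System.Threading.Channels;
using System.Threading.Tasks;

internal static class ChannelMigrationSketch
{
    public static async Task RunAsync()
    {
        // BufferBlock<int> queue = new BufferBlock<int>();        // old (Dataflow)
        Channel<int> queue = Channel.CreateUnbounded<int>(          // new (Channels)
            new UnboundedChannelOptions { SingleWriter = true });

        // Rough call mapping used throughout this PR:
        //   queue.Post(item)             -> queue.Writer.TryWrite(item)
        //   queue.Complete()             -> queue.Writer.Complete()
        //   queue.OutputAvailableAsync() -> queue.Reader.WaitToReadAsync()
        //   queue.TryReceive(out item)   -> queue.Reader.TryRead(out item)
        //   queue.ReceiveAsync(token)    -> queue.Reader.ReadAsync(token)

        Task producer = Task.Run(() =>
        {
            for (int i = 1; i <= 3; i++)
                queue.Writer.TryWrite(i);   // always succeeds on an open unbounded channel
            queue.Writer.Complete();        // signal that no more items will arrive
        });

        // Consume until the writer completes and the channel drains.
        while (await queue.Reader.WaitToReadAsync())
        {
            while (queue.Reader.TryRead(out int item))
                Console.WriteLine(item);
        }

        await producer;
    }
}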
1 change: 1 addition & 0 deletions build/Dependencies.props
@@ -9,6 +9,7 @@
<SystemMemoryVersion>4.5.1</SystemMemoryVersion>
<SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
<SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion>
<SystemThreadingChannelsPackageVersion>4.7.1</SystemThreadingChannelsPackageVersion>
</PropertyGroup>

<!-- Other/Non-Core Product Dependencies -->
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
@@ -10,6 +10,7 @@
<ItemGroup>
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
</ItemGroup>
40 changes: 20 additions & 20 deletions src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
@@ -5,8 +5,8 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
@@ -486,8 +486,8 @@ private static readonly FuncInstanceMethodInfo1<Cursor, int, Delegate> _createGe
private int _liveCount;
private bool _doneConsuming;

private readonly BufferBlock<int> _toProduce;
private readonly BufferBlock<int> _toConsume;
private readonly Channel<int> _toProduceChannel;
private readonly Channel<int> _toConsumeChannel;
private readonly Task _producerTask;
private Exception _producerTaskException;

@@ -541,13 +541,13 @@ public Cursor(IChannelProvider provider, int poolRows, DataViewRowCursor input,
_liveCount = 1;

// Set up the producer worker.
_toConsume = new BufferBlock<int>();
_toProduce = new BufferBlock<int>();
_toConsumeChannel = Channel.CreateUnbounded<int>(new UnboundedChannelOptions { SingleWriter = true });
_toProduceChannel = Channel.CreateUnbounded<int>(new UnboundedChannelOptions { SingleWriter = true });
// First request the pool - 1 + block size rows, to get us going.
PostAssert(_toProduce, _poolRows - 1 + _blockSize);
PostAssert(_toProduceChannel, _poolRows - 1 + _blockSize);
// Queue up the remaining capacity.
for (int i = 1; i < _bufferDepth; ++i)
PostAssert(_toProduce, _blockSize);
PostAssert(_toProduceChannel, _blockSize);

_producerTask = ProduceAsync();
}
@@ -559,28 +559,28 @@ protected override void Dispose(bool disposing)

if (disposing)
{
_toProduce.Complete();
_toProduceChannel.Writer.Complete();
_producerTask.Wait();

// Complete the consumer after the producerTask has finished, since producerTask could
// have posted more items to _toConsume.
_toConsume.Complete();
_toConsumeChannel.Writer.Complete();

// Drain both channels - with the old BufferBlocks this prevented what appeared to be memory leaks
// when using the VS Debugger, because if a BufferBlock still contained items, its underlying Tasks
// were not getting completed.
// See https://github.com/dotnet/corefx/issues/30582 for the VS Debugger issue.
// See also https://github.com/dotnet/machinelearning/issues/4399
_toProduce.TryReceiveAll(out _);
_toConsume.TryReceiveAll(out _);
while (_toProduceChannel.Reader.TryRead(out _)) { }
while (_toConsumeChannel.Reader.TryRead(out _)) { }
}

_disposed = true;
base.Dispose(disposing);
}

public static void PostAssert<T>(ITargetBlock<T> target, T item)
public static void PostAssert<T>(Channel<T> target, T item)
{
bool retval = target.Post(item);
bool retval = target.Writer.TryWrite(item);
Contracts.Assert(retval);
}

@@ -594,10 +594,10 @@ private async Task ProduceAsync()
try
{
int circularIndex = 0;
while (await _toProduce.OutputAvailableAsync().ConfigureAwait(false))
while (await _toProduceChannel.Reader.WaitToReadAsync().ConfigureAwait(false))
{
int requested;
if (!_toProduce.TryReceive(out requested))
if (!_toProduceChannel.Reader.TryRead(out requested))
{
// WaitToReadAsync returned true, but TryRead returned false -
// so loop back around and try again.
@@ -618,14 +618,14 @@ private async Task ProduceAsync()
if (circularIndex == _pipeIndices.Length)
circularIndex = 0;
}
PostAssert(_toConsume, numRows);
PostAssert(_toConsumeChannel, numRows);
if (numRows < requested)
{
// We've reached the end of the cursor. Send the sentinel, then exit.
// This assumes that the receiver will receive things in Post order
// (so that the sentinel is received, after the last Post).
if (numRows > 0)
PostAssert(_toConsume, 0);
PostAssert(_toConsumeChannel, 0);
return;
}
}
@@ -634,7 +634,7 @@
{
_producerTaskException = ex;
// Send the sentinel in this case as well, the field will be checked.
PostAssert(_toConsume, 0);
PostAssert(_toConsumeChannel, 0);
}
}

@@ -652,14 +652,14 @@ protected override bool MoveNextCore()
// We should let the producer know it can give us more stuff.
// It is possible for int values to be sent beyond the
// end of the sentinel, but we suppose this is irrelevant.
PostAssert(_toProduce, _deadCount);
PostAssert(_toProduceChannel, _deadCount);
_deadCount = 0;
}

while (_liveCount < _poolRows && !_doneConsuming)
{
// We are under capacity. Try to get some more.
int got = _toConsume.Receive();
int got = _toConsumeChannel.Reader.ReadAsync().AsTask().GetAwaiter().GetResult();
if (got == 0)
{
// We've reached the end sentinel. There's no reason
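The cursor changes above keep the original two-queue hand-off: the consumer posts block-size requests on one channel, the producer answers with row counts on the other, and a 0 acts as the end-of-data sentinel. A condensed, hypothetical sketch of that pattern with Channel&lt;T&gt; follows (simplified, not the transformer's actual code).

using System;
using System.Threading.Channels;
using System.Threading.Tasks;

internal sealed class SentinelHandOffSketch
{
    private readonly Channel<int> _toProduce = Channel.CreateUnbounded<int>();
    private readonly Channel<int> _toConsume = Channel.CreateUnbounded<int>();

    // Producer: honor each block-size request; post 0 as a sentinel when the source runs dry.
    public async Task ProduceAsync(int totalRows)
    {
        int remaining = totalRows;
        while (await _toProduce.Reader.WaitToReadAsync())
        {
            if (!_toProduce.Reader.TryRead(out int requested))
                continue;
            int produced = Math.Min(requested, remaining);
            remaining -= produced;
            _toConsume.Writer.TryWrite(produced);
            if (produced < requested)
            {
                if (produced > 0)
                    _toConsume.Writer.TryWrite(0); // sentinel: no more rows follow
                return;
            }
        }
    }

    // Consumer: request another block, then block until the producer reports how much arrived.
    public int ConsumeOnce(int request)
    {
        _toProduce.Writer.TryWrite(request);
        return _toConsume.Reader.ReadAsync().AsTask().GetAwaiter().GetResult();
    }
}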
13 changes: 7 additions & 6 deletions src/Microsoft.ML.Sweeper/AsyncSweeper.cs
@@ -5,8 +5,8 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
@@ -168,7 +168,7 @@ public sealed class Options
private readonly object _lock;
private readonly CancellationTokenSource _cts;

private readonly BufferBlock<ParameterSetWithId> _paramQueue;
private readonly Channel<ParameterSetWithId> _paramChannel;
private readonly int _relaxation;
private readonly ISweeper _baseSweeper;
private readonly IHost _host;
@@ -208,7 +208,8 @@ public DeterministicSweeperAsync(IHostEnvironment env, Options options)
_lock = new object();
_results = new List<IRunResult>();
_nullRuns = new HashSet<int>();
_paramQueue = new BufferBlock<ParameterSetWithId>();
_paramChannel = Channel.CreateUnbounded<ParameterSetWithId>(
new UnboundedChannelOptions { SingleWriter = true });

PrepareNextBatch(null);
}
@@ -220,12 +221,12 @@ private void PrepareNextBatch(IEnumerable<IRunResult> results)
if (Utils.Size(paramSets) == 0)
{
// Mark the queue as completed.
_paramQueue.Complete();
_paramChannel.Writer.Complete();
return;
}
// Assign an id to each ParameterSet and enqueue it.
foreach (var paramSet in paramSets)
_paramQueue.Post(new ParameterSetWithId(_numGenerated++, paramSet));
_paramChannel.Writer.TryWrite(new ParameterSetWithId(_numGenerated++, paramSet));
EnsureResultsSize();
}

@@ -278,7 +279,7 @@ public async Task<ParameterSetWithId> ProposeAsync()
return null;
try
{
return await _paramQueue.ReceiveAsync(_cts.Token);
return await _paramChannel.Reader.ReadAsync(_cts.Token);
}
catch (InvalidOperationException)
{
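One behavioral note on the ProposeAsync change above: BufferBlock&lt;T&gt;.ReceiveAsync throws InvalidOperationException when the block has completed and is empty, while ChannelReader&lt;T&gt;.ReadAsync throws ChannelClosedException, which derives from InvalidOperationException, so the existing catch block still applies. A small usage sketch of that pattern follows; the identifiers are illustrative, not from the PR.

using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

internal static class ProposeSketch
{
    // Returns the next queued item, or null once the channel is completed or the token is canceled.
    public static async Task<string> ProposeAsync(ChannelReader<string> reader, CancellationToken token)
    {
        try
        {
            return await reader.ReadAsync(token);
        }
        catch (InvalidOperationException) // also covers ChannelClosedException
        {
            return null;
        }
        catch (OperationCanceledException)
        {
            return null;
        }
    }
}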
4 changes: 4 additions & 0 deletions src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj
@@ -6,6 +6,10 @@
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />