Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Channel Instead of BufferBlock #5123

Merged
merged 14 commits into from
Jul 8, 2020
2 changes: 1 addition & 1 deletion build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<SystemCollectionsImmutableVersion>1.5.0</SystemCollectionsImmutableVersion>
<SystemMemoryVersion>4.5.1</SystemMemoryVersion>
<SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
<SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion>
<SystemThreadingChannelsPackageVersion>4.7.1</SystemThreadingChannelsPackageVersion>
jwood803 marked this conversation as resolved.
Show resolved Hide resolved
</PropertyGroup>

<!-- Other/Non-Core Product Dependencies -->
Expand Down
2 changes: 1 addition & 1 deletion pkg/Microsoft.ML/Microsoft.ML.nupkgproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
jwood803 marked this conversation as resolved.
Show resolved Hide resolved
<PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
<ItemGroup>
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
</ItemGroup>

Expand Down
88 changes: 34 additions & 54 deletions src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -486,13 +486,12 @@ private static readonly FuncInstanceMethodInfo1<Cursor, int, Delegate> _createGe
private int _liveCount;
private bool _doneConsuming;

private readonly BufferBlock<int> _toProduce;
private readonly BufferBlock<int> _toConsume;
private readonly Channel<int> _toProduceChannel;
private readonly Channel<int> _toConsumeChannel;
private readonly Task _producerTask;
private Exception _producerTaskException;

private readonly int[] _colToActivesIndex;
private bool _disposed;

public override DataViewSchema Schema => _input.Schema;

Expand Down Expand Up @@ -541,46 +540,20 @@ public Cursor(IChannelProvider provider, int poolRows, DataViewRowCursor input,
_liveCount = 1;

// Set up the producer worker.
_toConsume = new BufferBlock<int>();
_toProduce = new BufferBlock<int>();
_toConsumeChannel = Channel.CreateUnbounded<int>(new UnboundedChannelOptions { SingleWriter = true });
_toProduceChannel = Channel.CreateUnbounded<int>(new UnboundedChannelOptions { SingleWriter = true });
// First request the pool - 1 + block size rows, to get us going.
PostAssert(_toProduce, _poolRows - 1 + _blockSize);
PostAssert(_toProduceChannel, _poolRows - 1 + _blockSize);
// Queue up the remaining capacity.
for (int i = 1; i < _bufferDepth; ++i)
PostAssert(_toProduce, _blockSize);
PostAssert(_toProduceChannel, _blockSize);

_producerTask = ProduceAsync();
}

protected override void Dispose(bool disposing)
public static void PostAssert<T>(Channel<T> target, T item)
{
if (_disposed)
return;

if (disposing)
{
_toProduce.Complete();
_producerTask.Wait();

// Complete the consumer after the producerTask has finished, since producerTask could
// have posted more items to _toConsume.
_toConsume.Complete();

// Drain both BufferBlocks - this prevents what appears to be memory leaks when using the VS Debugger
// because if a BufferBlock still contains items, its underlying Tasks are not getting completed.
// See https://github.com/dotnet/corefx/issues/30582 for the VS Debugger issue.
// See also https://github.com/dotnet/machinelearning/issues/4399
_toProduce.TryReceiveAll(out _);
_toConsume.TryReceiveAll(out _);
}

_disposed = true;
base.Dispose(disposing);
}

public static void PostAssert<T>(ITargetBlock<T> target, T item)
{
bool retval = target.Post(item);
bool retval = target.Writer.TryWrite(item);
Contracts.Assert(retval);
}

Expand All @@ -594,12 +567,13 @@ private async Task ProduceAsync()
try
{
int circularIndex = 0;
while (await _toProduce.OutputAvailableAsync().ConfigureAwait(false))
while (await _toProduceChannel.Reader.WaitToReadAsync().ConfigureAwait(false))
jwood803 marked this conversation as resolved.
Show resolved Hide resolved
{
int requested;
if (!_toProduce.TryReceive(out requested))
if (!_toProduceChannel.Reader.TryRead(out requested))
{
// OutputAvailableAsync returned true, but TryReceive returned false -
// The producer Channel's Reader.WaitToReadAsync returned true,
// but the Reader's TryRead returned false -
// so loop back around and try again.
continue;
}
Expand All @@ -618,14 +592,14 @@ private async Task ProduceAsync()
if (circularIndex == _pipeIndices.Length)
circularIndex = 0;
}
PostAssert(_toConsume, numRows);
PostAssert(_toConsumeChannel, numRows);
if (numRows < requested)
{
// We've reached the end of the cursor. Send the sentinel, then exit.
// This assumes that the receiver will receive things in Post order
// (so that the sentinel is received, after the last Post).
if (numRows > 0)
PostAssert(_toConsume, 0);
PostAssert(_toConsumeChannel, 0);
return;
}
}
Expand All @@ -634,7 +608,7 @@ private async Task ProduceAsync()
{
_producerTaskException = ex;
// Send the sentinel in this case as well, the field will be checked.
PostAssert(_toConsume, 0);
PostAssert(_toConsumeChannel, 0);
}
}

Expand All @@ -651,26 +625,32 @@ protected override bool MoveNextCore()
{
// We should let the producer know it can give us more stuff.
// It is possible for int values to be sent beyond the
// end of the sentinel, but we suppose this is irrelevant.
PostAssert(_toProduce, _deadCount);
// sentinel, but we suppose this is irrelevant.
PostAssert(_toProduceChannel, _deadCount);
_deadCount = 0;
}

while (_liveCount < _poolRows && !_doneConsuming)
{
// We are under capacity. Try to get some more.
int got = _toConsume.Receive();
if (got == 0)
while (_toConsumeChannel.Reader.WaitToReadAsync().GetAwaiter().GetResult())
{
// We've reached the end sentinel. There's no reason
// to attempt further communication with the producer.
// Check whether something horrible happened.
if (_producerTaskException != null)
throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
_doneConsuming = true;
break;
var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
if (hasReadItem)
{
if (got == 0)
{
// We've read the end sentinel (0) from the Channel. There's no reason
// to attempt further communication with the producer.
// Check whether something horrible happened.
if (_producerTaskException != null)
throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
_doneConsuming = true;
break;
}
_liveCount += got;
}
}
_liveCount += got;
}
if (_liveCount == 0)
return false;
Expand Down
13 changes: 7 additions & 6 deletions src/Microsoft.ML.Sweeper/AsyncSweeper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
Expand Down Expand Up @@ -168,7 +168,7 @@ public sealed class Options
private readonly object _lock;
private readonly CancellationTokenSource _cts;

private readonly BufferBlock<ParameterSetWithId> _paramQueue;
private readonly Channel<ParameterSetWithId> _paramChannel;
private readonly int _relaxation;
private readonly ISweeper _baseSweeper;
private readonly IHost _host;
Expand Down Expand Up @@ -208,7 +208,8 @@ public DeterministicSweeperAsync(IHostEnvironment env, Options options)
_lock = new object();
_results = new List<IRunResult>();
_nullRuns = new HashSet<int>();
_paramQueue = new BufferBlock<ParameterSetWithId>();
_paramChannel = Channel.CreateUnbounded<ParameterSetWithId>(
new UnboundedChannelOptions { SingleWriter = true });

PrepareNextBatch(null);
}
Expand All @@ -220,12 +221,12 @@ private void PrepareNextBatch(IEnumerable<IRunResult> results)
if (Utils.Size(paramSets) == 0)
{
// Mark the queue as completed.
_paramQueue.Complete();
_paramChannel.Writer.Complete();
return;
}
// Assign an id to each ParameterSet and enqueue it.
foreach (var paramSet in paramSets)
_paramQueue.Post(new ParameterSetWithId(_numGenerated++, paramSet));
_paramChannel.Writer.TryWrite(new ParameterSetWithId(_numGenerated++, paramSet));
EnsureResultsSize();
}

Expand Down Expand Up @@ -278,7 +279,7 @@ public async Task<ParameterSetWithId> ProposeAsync()
return null;
try
{
return await _paramQueue.ReceiveAsync(_cts.Token);
return await _paramChannel.Reader.ReadAsync(_cts.Token);
}
catch (InvalidOperationException)
{
Expand Down
1 change: 0 additions & 1 deletion test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Google.Protobuf.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Newtonsoft.Json.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.CodeDom.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.Threading.Tasks.Dataflow.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.CpuMath.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Data.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Transforms.dll"
Expand Down