Skip to content

Commit

Permalink
Fixed issue that deadlocked DataLoader key batches (#7437)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Sep 6, 2024
1 parent 8929302 commit 0031a06
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/GreenDonut/src/Core/Batch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ internal class Batch<TKey> where TKey : notnull
private readonly List<TKey> _keys = [];
private readonly Dictionary<TKey, IPromise> _items = new();

public bool IsScheduled { get; set; }

public int Size => _keys.Count;

public IReadOnlyList<TKey> Keys => _keys;
Expand Down
35 changes: 22 additions & 13 deletions src/GreenDonut/src/Core/DataLoaderBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,7 @@ protected internal IBatchScheduler BatchScheduler
protected internal DataLoaderOptions Options
=> new()
{
MaxBatchSize = _maxBatchSize,
Cache = Cache,
DiagnosticEvents = _diagnosticEvents,
CancellationToken = _ct,
MaxBatchSize = _maxBatchSize, Cache = Cache, DiagnosticEvents = _diagnosticEvents, CancellationToken = _ct,
};

/// <inheritdoc />
Expand Down Expand Up @@ -180,12 +177,12 @@ protected internal DataLoaderOptions Options
{
Initialize();
}
}

// we dispatch after everything is enqueued.
if (_currentBatch is not null)
{
_batchScheduler.Schedule(() => DispatchBatchAsync(_currentBatch, _ct));
// we dispatch after everything is enqueued.
if (_currentBatch is { IsScheduled: false })
{
ScheduleBatch(_currentBatch);
}
}

return WhenAll();
Expand Down Expand Up @@ -280,13 +277,13 @@ public IDataLoader Branch<TState>(
throw new ArgumentNullException(nameof(createBranch));
}

if(!AllowBranching)
if (!AllowBranching)
{
throw new InvalidOperationException(
"Branching is not allowed for this DataLoader.");
}

if(!_branches.TryGetValue(key, out var branch))
if (!_branches.TryGetValue(key, out var branch))
{
lock (_sync)
{
Expand Down Expand Up @@ -401,20 +398,32 @@ async ValueTask StartDispatchingAsync()
return _currentBatch.GetOrCreatePromise<TValue?>(key, allowCachePropagation);
}

// If there is a current batch that was deliberately left unscheduled (for efficiency reasons),
// we schedule it now, before a new batch is issued.
if (!(_currentBatch?.IsScheduled ?? true))
{
ScheduleBatch(_currentBatch);
}

var newBatch = BatchPool<TKey>.Shared.Get();
var newPromise = newBatch.GetOrCreatePromise<TValue?>(key, allowCachePropagation);

// set the batch before enqueueing to avoid concurrency issues.
_currentBatch = newBatch;
if (scheduleOnNewBatch)
{
_batchScheduler.Schedule(() => DispatchBatchAsync(newBatch, _ct));
ScheduleBatch(newBatch);
}

return newPromise;
}

// ReSharper restore InconsistentlySynchronizedField
/// <summary>
/// Marks the given batch as scheduled and registers its dispatch with the batch scheduler.
/// </summary>
/// <param name="batch">
/// The batch to flag as scheduled and to hand to <c>DispatchBatchAsync</c>.
/// </param>
private void ScheduleBatch(Batch<TKey> batch)
{
// Set the flag before handing the batch over: the call sites in this class check
// IsScheduled to decide whether a batch still needs to be scheduled.
batch.IsScheduled = true;
// The scheduler receives a delegate that dispatches this batch with the loader's token (_ct).
// NOTE(review): when the scheduler invokes the delegate is not visible here — confirm against IBatchScheduler.
_batchScheduler.Schedule(() => DispatchBatchAsync(batch, _ct));
}

private void SetSingleResult(
Promise<TValue?> promise,
TKey key,
Expand Down
74 changes: 74 additions & 0 deletions src/GreenDonut/test/Core.Tests/DataLoaderListBatchTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace GreenDonut;

/// <summary>
/// Tests that load 5000 keys through a <see cref="TestDataLoader"/> in a single
/// <c>LoadAsync</c> call, both once and from several concurrent tasks, and verify
/// that every key yields a result.
/// </summary>
public static class DataLoaderListBatchTests
{
    /// <summary>
    /// Loads 5000 keys in one call and expects 5000 results.
    /// </summary>
    [Fact]
    public static async Task Overflow_InternalBatch_Async()
    {
        // arrange
        // The 5 second timeout token is handed to LoadAsync (as in the concurrent test below)
        // so that the timeout is actually in effect instead of being created and ignored.
        using var cts = new CancellationTokenSource(5000);
        var ct = cts.Token;
        var services = new ServiceCollection()
            .AddDataLoader<TestDataLoader>()
            .BuildServiceProvider();
        var dataLoader = services.GetRequiredService<TestDataLoader>();

        // act
        var result = await dataLoader.LoadAsync(
            Enumerable.Range(0, 5000).ToArray(),
            ct);

        // assert
        Assert.Equal(5000, result.Count);
    }

    /// <summary>
    /// Starts ten tasks that each load the same 5000 keys through one shared
    /// DataLoader instance and expects each of them to receive 5000 results.
    /// </summary>
    [Fact]
    public static async Task Ensure_Multiple_Large_Batches_Can_Be_Enqueued_Concurrently_Async()
    {
        // arrange
        using var cts = new CancellationTokenSource(5000);
        var ct = cts.Token;
        var services = new ServiceCollection()
            .AddDataLoader<TestDataLoader>()
            .BuildServiceProvider();
        var dataLoader = services.GetRequiredService<TestDataLoader>();

        // act
        List<Task> tasks = new();
        foreach (var _ in Enumerable.Range(0, 10))
        {
            tasks.Add(
                Task.Run(
                    async () =>
                    {
                        var result = await dataLoader.LoadAsync(Enumerable.Range(0, 5000).ToArray(), ct);

                        // assert
                        Assert.Equal(5000, result.Count);
                    },
                    ct));
        }

        await Task.WhenAll(tasks);
    }

    /// <summary>
    /// DataLoader used by the tests: after a 300 ms delay it maps every requested key
    /// to an array holding the integers 0 through 499.
    /// </summary>
    public sealed class TestDataLoader(
        IBatchScheduler batchScheduler,
        DataLoaderOptions options)
        : BatchDataLoader<int, int[]>(batchScheduler, options)
    {
        /// <summary>
        /// Produces the result dictionary for one batch of keys.
        /// </summary>
        /// <param name="runNumbers">The keys of the batch.</param>
        /// <param name="cancellationToken">Token observed by the artificial delay.</param>
        /// <returns>A dictionary with one 500-element array per key.</returns>
        protected override async Task<IReadOnlyDictionary<int, int[]>> LoadBatchAsync(
            IReadOnlyList<int> runNumbers,
            CancellationToken cancellationToken)
        {
            // Artificial latency so that a batch takes a noticeable time to complete.
            await Task.Delay(300, cancellationToken).ConfigureAwait(false);

            // Each key gets its own freshly allocated array.
            return runNumbers.ToDictionary(
                key => key,
                _ => Enumerable.Range(0, 500).ToArray());
        }
    }
}

0 comments on commit 0031a06

Please sign in to comment.