Skip to content

Commit

Permalink
Backported inline DataLoader fix from 13 (#4916)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Apr 3, 2022
1 parent ee6184e commit 4b991aa
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 10 deletions.
15 changes: 7 additions & 8 deletions src/GreenDonut/src/Core/DataLoaderBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public abstract partial class DataLoaderBase<TKey, TValue>
private readonly object _sync = new();
private readonly CancellationTokenSource _disposeTokenSource = new();
private readonly IBatchScheduler _batchScheduler;
private readonly string _cacheKeyType;
private readonly int _maxBatchSize;
private readonly ITaskCache? _cache;
private readonly TaskCacheOwner? _cacheOwner;
Expand Down Expand Up @@ -69,7 +68,7 @@ protected DataLoaderBase(IBatchScheduler batchScheduler, DataLoaderOptions? opti

_batchScheduler = batchScheduler;
_maxBatchSize = options.MaxBatchSize;
_cacheKeyType = GetCacheKeyType(GetType());
CacheKeyType = GetCacheKeyType(GetType());
}

/// <summary>
Expand All @@ -80,7 +79,7 @@ protected DataLoaderBase(IBatchScheduler batchScheduler, DataLoaderOptions? opti
/// <summary>
/// Gets the cache key type for this DataLoader.
/// </summary>
protected string CacheKeyType => _cacheKeyType;
protected virtual string CacheKeyType { get; }

/// <inheritdoc />
public Task<TValue> LoadAsync(TKey key, CancellationToken cancellationToken = default)
Expand All @@ -91,7 +90,7 @@ public Task<TValue> LoadAsync(TKey key, CancellationToken cancellationToken = de
}

var cached = true;
TaskCacheKey cacheKey = new(_cacheKeyType, key);
TaskCacheKey cacheKey = new(CacheKeyType, key);

lock (_sync)
{
Expand Down Expand Up @@ -154,7 +153,7 @@ void InitializeWithCache()

cached = true;
currentKey = key;
TaskCacheKey cacheKey = new(_cacheKeyType, key);
TaskCacheKey cacheKey = new(CacheKeyType, key);
Task<TValue> cachedTask = _cache.GetOrAddTask(cacheKey, CreatePromise);

if (cached)
Expand Down Expand Up @@ -197,7 +196,7 @@ public void Remove(TKey key)

if (_cache is not null)
{
TaskCacheKey cacheKey = new(_cacheKeyType, key);
TaskCacheKey cacheKey = new(CacheKeyType, key);
_cache.TryRemove(cacheKey);
}
}
Expand All @@ -217,7 +216,7 @@ public void Set(TKey key, Task<TValue> value)

if (_cache is not null)
{
TaskCacheKey cacheKey = new(_cacheKeyType, key);
TaskCacheKey cacheKey = new(CacheKeyType, key);
_cache.TryAdd(cacheKey, value);
}
}
Expand All @@ -233,7 +232,7 @@ private void BatchOperationFailed(
{
if (_cache is not null)
{
TaskCacheKey cacheKey = new(_cacheKeyType, key);
TaskCacheKey cacheKey = new(CacheKeyType, key);
_cache.TryRemove(cacheKey);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public static IDataLoader<TKey, TValue> BatchDataLoader<TKey, TValue>(
IDataLoaderRegistry reg = services.GetRequiredService<IDataLoaderRegistry>();
FetchBatchDataLoader<TKey, TValue> Loader()
=> new(
key ?? "default",
fetch,
services.GetRequiredService<IBatchScheduler>(),
services.GetRequiredService<DataLoaderOptions>());
Expand Down Expand Up @@ -80,6 +81,7 @@ public static IDataLoader<TKey, TValue[]> GroupDataLoader<TKey, TValue>(
IDataLoaderRegistry reg = services.GetRequiredService<IDataLoaderRegistry>();
FetchGroupedDataLoader<TKey, TValue> Loader()
=> new(
key ?? "default",
fetch,
services.GetRequiredService<IBatchScheduler>(),
services.GetRequiredService<DataLoaderOptions>());
Expand Down Expand Up @@ -125,7 +127,7 @@ public static IDataLoader<TKey, TValue> CacheDataLoader<TKey, TValue>(
IServiceProvider services = context.Services;
IDataLoaderRegistry reg = services.GetRequiredService<IDataLoaderRegistry>();
FetchCacheDataLoader<TKey, TValue> Loader()
=> new(fetch, services.GetRequiredService<DataLoaderOptions>());
=> new(key ?? "default", fetch, services.GetRequiredService<DataLoaderOptions>());

return key is null
? reg.GetOrRegister(Loader)
Expand Down
5 changes: 5 additions & 0 deletions src/HotChocolate/Core/src/Fetching/FetchBatchDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@ internal sealed class FetchBatchDataLoader<TKey, TValue>
private readonly FetchBatch<TKey, TValue> _fetch;

public FetchBatchDataLoader(
string key,
FetchBatch<TKey, TValue> fetch,
IBatchScheduler batchScheduler,
DataLoaderOptions options)
: base(batchScheduler, options)
{
_fetch = fetch ?? throw new ArgumentNullException(nameof(fetch));
CacheKeyType = $"{GetCacheKeyType(GetType())}-{key}";
}

protected override string CacheKeyType { get; }


protected override Task<IReadOnlyDictionary<TKey, TValue>> LoadBatchAsync(
IReadOnlyList<TKey> keys,
CancellationToken cancellationToken) =>
Expand Down
8 changes: 7 additions & 1 deletion src/HotChocolate/Core/src/Fetching/FetchCacheDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@ internal sealed class FetchCacheDataLoader<TKey, TValue>
{
private readonly FetchCacheCt<TKey, TValue> _fetch;

public FetchCacheDataLoader(FetchCacheCt<TKey, TValue> fetch, DataLoaderOptions options)
public FetchCacheDataLoader(
string key,
FetchCacheCt<TKey, TValue> fetch,
DataLoaderOptions options)
: base(options)
{
_fetch = fetch ?? throw new ArgumentNullException(nameof(fetch));
CacheKeyType = $"{GetCacheKeyType(GetType())}-{key}";
}

protected override string CacheKeyType { get; }

protected override Task<TValue> LoadSingleAsync(
TKey key,
CancellationToken cancellationToken) =>
Expand Down
4 changes: 4 additions & 0 deletions src/HotChocolate/Core/src/Fetching/FetchGroupedDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@ internal sealed class FetchGroupedDataLoader<TKey, TValue>
private readonly FetchGroup<TKey, TValue> _fetch;

public FetchGroupedDataLoader(
string key,
FetchGroup<TKey, TValue> fetch,
IBatchScheduler batchScheduler,
DataLoaderOptions options)
: base(batchScheduler, options)
{
_fetch = fetch ?? throw new ArgumentNullException(nameof(fetch));
CacheKeyType = $"{GetCacheKeyType(GetType())}-{key}";
}

protected override string CacheKeyType { get; }

protected override Task<ILookup<TKey, TValue>> LoadGroupedBatchAsync(
IReadOnlyList<TKey> keys,
CancellationToken cancellationToken) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\Core\HotChocolate.Core.csproj" />
<ProjectReference Include="..\..\src\Fetching\HotChocolate.Fetching.csproj" />
<ProjectReference Include="..\Utilities\HotChocolate.Tests.Utilities.csproj" />
</ItemGroup>

<!--For Visual Studio for Mac Test Explorer we need this reference here-->
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using HotChocolate.Execution;
using HotChocolate.Resolvers;
using HotChocolate.Tests;
using HotChocolate.Types;
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace HotChocolate.Fetching;

public class InlineBatchDataLoaderTests
{
[Fact]
public async Task LoadWithDifferentDataLoader()
{
// arrange
IRequestExecutor executor =
await new ServiceCollection()
.AddGraphQL()
.AddQueryType<Query>()
.BuildRequestExecutorAsync();

// act
var result1 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();
var result2 = await executor.ExecuteAsync("{ byKey(key: \"def\") }").ToJsonAsync();

// assert
Assert.NotEqual(result1, result2);
}

[Fact]
public async Task LoadWithSingleKeyDataLoader()
{
// arrange
IRequestExecutor executor =
await new ServiceCollection()
.AddGraphQL()
.AddQueryType<Query>()
.BuildRequestExecutorAsync();

// act
var result1 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();
var result2 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();

// assert
Assert.Equal(result1, result2);
}

public class Query
{
public async Task<string> GetByKey(string key, IResolverContext context)
{
return await context
.BatchDataLoader<string, string>(
(keys, _) =>
Task.FromResult<IReadOnlyDictionary<string, string>>(
keys.ToDictionary(t => t, _ => key)),
key)
.LoadAsync("abc", context.RequestAborted);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System.Threading;
using System.Threading.Tasks;
using HotChocolate.Execution;
using HotChocolate.Resolvers;
using HotChocolate.Tests;
using HotChocolate.Types;
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace HotChocolate.Fetching;

public class InlineCacheDataLoaderTests
{
[Fact]
public async Task LoadWithDifferentDataLoader()
{
// arrange
IRequestExecutor executor =
await new ServiceCollection()
.AddGraphQL()
.AddQueryType<Query>()
.BuildRequestExecutorAsync();

// act
var result1 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();
var result2 = await executor.ExecuteAsync("{ byKey(key: \"def\") }").ToJsonAsync();

// assert
Assert.NotEqual(result1, result2);
}

[Fact]
public async Task LoadWithSingleKeyDataLoader()
{
// arrange
IRequestExecutor executor =
await new ServiceCollection()
.AddGraphQL()
.AddQueryType<Query>()
.BuildRequestExecutorAsync();

// act
var result1 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();
var result2 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();

// assert
Assert.Equal(result1, result2);
}

public class Query
{
public async Task<string> GetByKey(string key, IResolverContext context)
{
return await context
.CacheDataLoader<string, string>(
(_, _) => Task.FromResult(key),
key)
.LoadAsync("abc", context.RequestAborted);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System.Linq;
using System.Threading.Tasks;
using HotChocolate.Execution;
using HotChocolate.Resolvers;
using HotChocolate.Tests;
using HotChocolate.Types;
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace HotChocolate.Fetching;

public class InlineGroupDataLoaderTests
{
[Fact]
public async Task LoadWithDifferentDataLoader()
{
// arrange
IRequestExecutor executor =
await new ServiceCollection()
.AddGraphQL()
.AddQueryType<Query>()
.BuildRequestExecutorAsync();

// act
var result1 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();
var result2 = await executor.ExecuteAsync("{ byKey(key: \"def\") }").ToJsonAsync();

// assert
Assert.NotEqual(result1, result2);
}

[Fact]
public async Task LoadWithSingleKeyDataLoader()
{
// arrange
IRequestExecutor executor =
await new ServiceCollection()
.AddGraphQL()
.AddQueryType<Query>()
.BuildRequestExecutorAsync();

// act
var result1 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();
var result2 = await executor.ExecuteAsync("{ byKey(key: \"abc\") }").ToJsonAsync();

// assert
Assert.Equal(result1, result2);
}

public class Query
{
public async Task<string[]> GetByKey(string key, IResolverContext context)
{
return await context
.GroupDataLoader<string, string>(
(keys, _) => Task.FromResult(keys.ToLookup(t => t, _ => key)),
key)
.LoadAsync("abc", context.RequestAborted);
}
}
}

0 comments on commit 4b991aa

Please sign in to comment.