Skip to content

Commit 6f8272b

Browse files
authored
Added more overloads for list selects. (#7578)
1 parent 34cb8a9 commit 6f8272b

File tree

10 files changed

+898
-36
lines changed

10 files changed

+898
-36
lines changed

src/GreenDonut/src/Core/DataLoaderFetchContext.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,6 @@ public ISelectorBuilder GetSelector()
152152

153153
// if no selector was found we will just return
154154
// a new default selector builder.
155-
return new DefaultSelectorBuilder<TValue>();
155+
return new DefaultSelectorBuilder();
156156
}
157157
}

src/GreenDonut/src/Core/Projections/DefaultSelectorBuilder.cs

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,14 @@ namespace GreenDonut.Projections;
66
/// <summary>
77
/// A default implementation of the <see cref="ISelectorBuilder"/>.
88
/// </summary>
9-
/// <typeparam name="TValue"></typeparam>
109
[Experimental(Experiments.Projections)]
11-
public sealed class DefaultSelectorBuilder<TValue> : ISelectorBuilder
10+
public sealed class DefaultSelectorBuilder : ISelectorBuilder
1211
{
1312
private List<LambdaExpression>? _selectors;
1413

1514
/// <inheritdoc />
1615
public void Add<T>(Expression<Func<T, T>> selector)
1716
{
18-
if (typeof(T) != typeof(TValue))
19-
{
20-
throw new ArgumentException(
21-
"The projection type must match the DataLoader value type.",
22-
nameof(selector));
23-
}
24-
2517
_selectors ??= new List<LambdaExpression>();
2618
if (!_selectors.Contains(selector))
2719
{
@@ -37,11 +29,6 @@ public void Add<T>(Expression<Func<T, T>> selector)
3729
return null;
3830
}
3931

40-
if (typeof(T) != typeof(TValue))
41-
{
42-
return null;
43-
}
44-
4532
if (_selectors.Count == 1)
4633
{
4734
return (Expression<Func<T, T>>)_selectors[0];

src/GreenDonut/src/Core/Projections/SelectionDataLoaderExtensions.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,18 @@ public static class SelectionDataLoaderExtensions
3232
/// <typeparam name="TValue">
3333
/// The value type.
3434
/// </typeparam>
35+
/// <typeparam name="TElement">
36+
/// The element type.
37+
/// </typeparam>
3538
/// <returns>
3639
/// Returns a branched DataLoader with the selector applied.
3740
/// </returns>
3841
/// <exception cref="ArgumentNullException">
3942
/// Throws if <paramref name="dataLoader"/> is <c>null</c>.
4043
/// </exception>
41-
public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
44+
public static IDataLoader<TKey, TValue> Select<TKey, TValue, TElement>(
4245
this IDataLoader<TKey, TValue> dataLoader,
43-
Expression<Func<TValue, TValue>>? selector)
46+
Expression<Func<TElement, TElement>>? selector)
4447
where TKey : notnull
4548
{
4649
if (dataLoader is null)
@@ -55,7 +58,7 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
5558

5659
if (dataLoader is ISelectionDataLoader<TKey, TValue>)
5760
{
58-
var context = (DefaultSelectorBuilder<TValue>)dataLoader.ContextData[typeof(ISelectorBuilder).FullName!]!;
61+
var context = (DefaultSelectorBuilder)dataLoader.ContextData[typeof(ISelectorBuilder).FullName!]!;
5962
context.Add(selector);
6063
return dataLoader;
6164
}
@@ -66,12 +69,12 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
6669
static IDataLoader CreateBranch(
6770
string key,
6871
IDataLoader<TKey, TValue> dataLoader,
69-
Expression<Func<TValue, TValue>> selector)
72+
Expression<Func<TElement, TElement>> selector)
7073
{
7174
var branch = new SelectionDataLoader<TKey, TValue>(
7275
(DataLoaderBase<TKey, TValue>)dataLoader,
7376
key);
74-
var context = new DefaultSelectorBuilder<TValue>();
77+
var context = new DefaultSelectorBuilder();
7578
branch.ContextData = branch.ContextData.SetItem(typeof(ISelectorBuilder).FullName!, context);
7679
context.Add(selector);
7780
return branch;
@@ -134,7 +137,7 @@ public static IDataLoader<TKey, TValue> Include<TKey, TValue>(
134137

135138
var context = dataLoader.GetOrSetState(
136139
typeof(ISelectorBuilder).FullName!,
137-
_ => new DefaultSelectorBuilder<TValue>());
140+
_ => new DefaultSelectorBuilder());
138141
context.Add(Rewrite(includeSelector));
139142
return dataLoader;
140143
}
@@ -235,7 +238,7 @@ public static IQueryable<KeyValueResult<TKey, IEnumerable<TValue>>> Select<T, TK
235238
{
236239
var selectMethod = _selectMethod.MakeGenericMethod(typeof(TValue), typeof(TValue));
237240

238-
list = Expression.Lambda<Func<T, IEnumerable<TValue>>>(
241+
rewrittenList = Expression.Lambda<Func<T, IEnumerable<TValue>>>(
239242
Expression.Call(
240243
selectMethod,
241244
rewrittenList.Body,
@@ -254,7 +257,7 @@ public static IQueryable<KeyValueResult<TKey, IEnumerable<TValue>>> Select<T, TK
254257
Expression.Bind(
255258
typeof(KeyValueResult<TKey, IEnumerable<TValue>>).GetProperty(
256259
nameof(KeyValueResult<TKey, IEnumerable<TValue>>.Value))!,
257-
list.Body)),
260+
rewrittenList.Body)),
258261
parameter);
259262

260263
// lastly we apply the selector expression to the queryable.

src/HotChocolate/Core/src/Execution/Extensions/HotChocolateExecutionSelectionExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public static Expression<Func<TValue, TValue>> AsSelector<TValue>(
5050

5151
if ((flags & FieldFlags.Connection) == FieldFlags.Connection)
5252
{
53-
var builder = new DefaultSelectorBuilder<TValue>();
53+
var builder = new DefaultSelectorBuilder();
5454
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
5555
var count = GetConnectionSelections(selection, buffer);
5656
for (var i = 0; i < count; i++)
@@ -63,7 +63,7 @@ public static Expression<Func<TValue, TValue>> AsSelector<TValue>(
6363

6464
if ((flags & FieldFlags.CollectionSegment) == FieldFlags.CollectionSegment)
6565
{
66-
var builder = new DefaultSelectorBuilder<TValue>();
66+
var builder = new DefaultSelectorBuilder();
6767
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
6868
var count = GetCollectionSelections(selection, buffer);
6969
for (var i = 0; i < count; i++)

src/HotChocolate/Core/src/Execution/Projections/HotChocolateExecutionDataLoaderExtensions.cs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,118 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
4848
return dataLoader.Select(expression);
4949
}
5050

51+
/// <summary>
52+
/// Selects the fields that where selected in the GraphQL selection tree.
53+
/// </summary>
54+
/// <param name="dataLoader">
55+
/// The data loader.
56+
/// </param>
57+
/// <param name="selection">
58+
/// The selection that shall be applied to the data loader.
59+
/// </param>
60+
/// <typeparam name="TKey">
61+
/// The key type.
62+
/// </typeparam>
63+
/// <typeparam name="TValue">
64+
/// The value type.
65+
/// </typeparam>
66+
/// <returns>
67+
/// Returns a new data loader that applies the selection.
68+
/// </returns>
69+
public static IDataLoader<TKey, TValue[]> Select<TKey, TValue>(
70+
this IDataLoader<TKey, TValue[]> dataLoader,
71+
ISelection selection)
72+
where TKey : notnull
73+
where TValue : notnull
74+
{
75+
var expression = selection.AsSelector<TValue>();
76+
return dataLoader.Select(expression);
77+
}
78+
79+
/// <summary>
80+
/// Selects the fields that where selected in the GraphQL selection tree.
81+
/// </summary>
82+
/// <param name="dataLoader">
83+
/// The data loader.
84+
/// </param>
85+
/// <param name="selection">
86+
/// The selection that shall be applied to the data loader.
87+
/// </param>
88+
/// <typeparam name="TKey">
89+
/// The key type.
90+
/// </typeparam>
91+
/// <typeparam name="TValue">
92+
/// The value type.
93+
/// </typeparam>
94+
/// <returns>
95+
/// Returns a new data loader that applies the selection.
96+
/// </returns>
97+
public static IDataLoader<TKey, ICollection<TValue>> Select<TKey, TValue>(
98+
this IDataLoader<TKey, ICollection<TValue>> dataLoader,
99+
ISelection selection)
100+
where TKey : notnull
101+
where TValue : notnull
102+
{
103+
var expression = selection.AsSelector<TValue>();
104+
return dataLoader.Select(expression);
105+
}
106+
107+
/// <summary>
108+
/// Selects the fields that where selected in the GraphQL selection tree.
109+
/// </summary>
110+
/// <param name="dataLoader">
111+
/// The data loader.
112+
/// </param>
113+
/// <param name="selection">
114+
/// The selection that shall be applied to the data loader.
115+
/// </param>
116+
/// <typeparam name="TKey">
117+
/// The key type.
118+
/// </typeparam>
119+
/// <typeparam name="TValue">
120+
/// The value type.
121+
/// </typeparam>
122+
/// <returns>
123+
/// Returns a new data loader that applies the selection.
124+
/// </returns>
125+
public static IDataLoader<TKey, IEnumerable<TValue>> Select<TKey, TValue>(
126+
this IDataLoader<TKey, IEnumerable<TValue>> dataLoader,
127+
ISelection selection)
128+
where TKey : notnull
129+
where TValue : notnull
130+
{
131+
var expression = selection.AsSelector<TValue>();
132+
return dataLoader.Select(expression);
133+
}
134+
135+
/// <summary>
136+
/// Selects the fields that where selected in the GraphQL selection tree.
137+
/// </summary>
138+
/// <param name="dataLoader">
139+
/// The data loader.
140+
/// </param>
141+
/// <param name="selection">
142+
/// The selection that shall be applied to the data loader.
143+
/// </param>
144+
/// <typeparam name="TKey">
145+
/// The key type.
146+
/// </typeparam>
147+
/// <typeparam name="TValue">
148+
/// The value type.
149+
/// </typeparam>
150+
/// <returns>
151+
/// Returns a new data loader that applies the selection.
152+
/// </returns>
153+
public static IDataLoader<TKey, List<TValue>> Select<TKey, TValue>(
154+
this IDataLoader<TKey, List<TValue>> dataLoader,
155+
ISelection selection)
156+
where TKey : notnull
157+
where TValue : notnull
158+
{
159+
var expression = selection.AsSelector<TValue>();
160+
return dataLoader.Select(expression);
161+
}
162+
51163
/// <summary>
52164
/// Selects the fields that where selected in the GraphQL selection tree.
53165
/// </summary>

src/HotChocolate/Core/src/Types.Analyzers/FileBuilders/DataLoaderFileBuilder.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,7 @@ public void WriteDataLoaderLoadMethod(
264264
parameter.StateKey);
265265
_writer.IncreaseIndent();
266266
_writer.WriteIndentedLine(
267-
"?? new global::GreenDonut.Projections.DefaultSelectorBuilder<{0}>();",
268-
value.ToFullyQualified());
267+
"?? new global::GreenDonut.Projections.DefaultSelectorBuilder();");
269268
_writer.DecreaseIndent();
270269
}
271270
else if (parameter.Kind is DataLoaderParameterKind.PagingArguments)

src/HotChocolate/Core/test/Execution.Tests/Projections/ProjectableDataLoaderTests.cs

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,38 @@ public async Task Project_Key_To_Collection_Expression()
728728
.MatchMarkdownSnapshot();
729729
}
730730

731+
[Fact]
732+
public async Task Project_Key_To_Collection_Expression_Integration()
733+
{
734+
// Arrange
735+
var queries = new List<string>();
736+
var connectionString = CreateConnectionString();
737+
await CatalogContext.SeedAsync(connectionString);
738+
739+
// Act
740+
var result = await new ServiceCollection()
741+
.AddScoped(_ => queries)
742+
.AddTransient(_ => new CatalogContext(connectionString))
743+
.AddGraphQLServer()
744+
.AddQueryType<BrandsQuery>()
745+
.AddTypeExtension(typeof(BrandListExtensions))
746+
.ExecuteRequestAsync(
747+
"""
748+
{
749+
brandById(id: 1) {
750+
products {
751+
name
752+
}
753+
}
754+
}
755+
""");
756+
// Assert
757+
Snapshot.Create()
758+
.AddSql(queries)
759+
.Add(result, "Result")
760+
.MatchMarkdownSnapshot();
761+
}
762+
731763
public class Query
732764
{
733765
public async Task<Brand?> GetBrandByIdAsync(
@@ -792,6 +824,19 @@ public class NodeQuery
792824
}
793825
}
794826

827+
public class BrandsQuery
828+
{
829+
public async Task<IEnumerable<Brand>> GetBrandByIdAsync(
830+
int id,
831+
ISelection selection,
832+
CatalogContext context,
833+
CancellationToken cancellationToken)
834+
=> await context.Brands
835+
.Select(selection.AsSelector<Brand>())
836+
.Take(2)
837+
.ToListAsync(cancellationToken);
838+
}
839+
795840
[ExtendObjectType<Brand>]
796841
public class BrandExtensions
797842
{
@@ -810,6 +855,18 @@ public string GetDetails(
810855
=> "Brand Name:" + brand.Name;
811856
}
812857

858+
[ExtendObjectType<Brand>]
859+
public class BrandListExtensions
860+
{
861+
[BindMember(nameof(Brand.Products))]
862+
public async Task<IEnumerable<Product>?> GetProductsAsync(
863+
[Parent] Brand brand,
864+
ProductByBrandIdDataLoader2 productByBrandId,
865+
ISelection selection,
866+
CancellationToken cancellationToken)
867+
=> await productByBrandId.Select(selection).LoadAsync(brand.Id, cancellationToken);
868+
}
869+
813870
public class BrandWithRequirementType : ObjectType<Brand>
814871
{
815872
protected override void Configure(IObjectTypeDescriptor<Brand> descriptor)
@@ -934,7 +991,7 @@ protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsyn
934991
CancellationToken cancellationToken)
935992
{
936993
var catalogContext = services.GetRequiredService<CatalogContext>();
937-
var selector = new DefaultSelectorBuilder<Product>();
994+
var selector = new DefaultSelectorBuilder();
938995
selector.Add<Product>(t => new Product { Name = t.Name });
939996

940997
var query = catalogContext.Brands
@@ -951,6 +1008,35 @@ protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsyn
9511008
return x;
9521009
}
9531010
}
1011+
1012+
public class ProductByBrandIdDataLoader2(
1013+
IServiceProvider services,
1014+
List<string> queries,
1015+
IBatchScheduler batchScheduler,
1016+
DataLoaderOptions options)
1017+
: StatefulBatchDataLoader<int, Product[]>(batchScheduler, options)
1018+
{
1019+
protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsync(
1020+
IReadOnlyList<int> keys,
1021+
DataLoaderFetchContext<Product[]> context,
1022+
CancellationToken cancellationToken)
1023+
{
1024+
var catalogContext = services.GetRequiredService<CatalogContext>();
1025+
1026+
var query = catalogContext.Brands
1027+
.Where(t => keys.Contains(t.Id))
1028+
.Select(t => t.Id, t => t.Products, context.GetSelector());
1029+
1030+
lock (queries)
1031+
{
1032+
queries.Add(query.ToQueryString());
1033+
}
1034+
1035+
var x = await query.ToDictionaryAsync(t => t.Key, t => t.Value.ToArray(), cancellationToken);
1036+
1037+
return x;
1038+
}
1039+
}
9541040
}
9551041

9561042
file static class Extensions

0 commit comments

Comments
 (0)