Skip to content

Commit

Permalink
Added more overloads for list selects. (#7578)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Oct 8, 2024
1 parent 34cb8a9 commit 6f8272b
Show file tree
Hide file tree
Showing 10 changed files with 898 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/GreenDonut/src/Core/DataLoaderFetchContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,6 @@ public ISelectorBuilder GetSelector()

// if no selector was found we will just return
// a new default selector builder.
return new DefaultSelectorBuilder<TValue>();
return new DefaultSelectorBuilder();
}
}
15 changes: 1 addition & 14 deletions src/GreenDonut/src/Core/Projections/DefaultSelectorBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,14 @@ namespace GreenDonut.Projections;
/// <summary>
/// A default implementation of the <see cref="ISelectorBuilder"/>.
/// </summary>
/// <typeparam name="TValue"></typeparam>
[Experimental(Experiments.Projections)]
public sealed class DefaultSelectorBuilder<TValue> : ISelectorBuilder
public sealed class DefaultSelectorBuilder : ISelectorBuilder
{
private List<LambdaExpression>? _selectors;

/// <inheritdoc />
public void Add<T>(Expression<Func<T, T>> selector)
{
if (typeof(T) != typeof(TValue))
{
throw new ArgumentException(
"The projection type must match the DataLoader value type.",
nameof(selector));
}

_selectors ??= new List<LambdaExpression>();
if (!_selectors.Contains(selector))
{
Expand All @@ -37,11 +29,6 @@ public void Add<T>(Expression<Func<T, T>> selector)
return null;
}

if (typeof(T) != typeof(TValue))
{
return null;
}

if (_selectors.Count == 1)
{
return (Expression<Func<T, T>>)_selectors[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,18 @@ public static class SelectionDataLoaderExtensions
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <typeparam name="TElement">
/// The element type.
/// </typeparam>
/// <returns>
/// Returns a branched DataLoader with the selector applied.
/// </returns>
/// <exception cref="ArgumentNullException">
/// Throws if <paramref name="dataLoader"/> is <c>null</c>.
/// </exception>
public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
public static IDataLoader<TKey, TValue> Select<TKey, TValue, TElement>(
this IDataLoader<TKey, TValue> dataLoader,
Expression<Func<TValue, TValue>>? selector)
Expression<Func<TElement, TElement>>? selector)
where TKey : notnull
{
if (dataLoader is null)
Expand All @@ -55,7 +58,7 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(

if (dataLoader is ISelectionDataLoader<TKey, TValue>)
{
var context = (DefaultSelectorBuilder<TValue>)dataLoader.ContextData[typeof(ISelectorBuilder).FullName!]!;
var context = (DefaultSelectorBuilder)dataLoader.ContextData[typeof(ISelectorBuilder).FullName!]!;
context.Add(selector);
return dataLoader;
}
Expand All @@ -66,12 +69,12 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
static IDataLoader CreateBranch(
string key,
IDataLoader<TKey, TValue> dataLoader,
Expression<Func<TValue, TValue>> selector)
Expression<Func<TElement, TElement>> selector)
{
var branch = new SelectionDataLoader<TKey, TValue>(
(DataLoaderBase<TKey, TValue>)dataLoader,
key);
var context = new DefaultSelectorBuilder<TValue>();
var context = new DefaultSelectorBuilder();
branch.ContextData = branch.ContextData.SetItem(typeof(ISelectorBuilder).FullName!, context);
context.Add(selector);
return branch;
Expand Down Expand Up @@ -134,7 +137,7 @@ public static IDataLoader<TKey, TValue> Include<TKey, TValue>(

var context = dataLoader.GetOrSetState(
typeof(ISelectorBuilder).FullName!,
_ => new DefaultSelectorBuilder<TValue>());
_ => new DefaultSelectorBuilder());
context.Add(Rewrite(includeSelector));
return dataLoader;
}
Expand Down Expand Up @@ -235,7 +238,7 @@ public static IQueryable<KeyValueResult<TKey, IEnumerable<TValue>>> Select<T, TK
{
var selectMethod = _selectMethod.MakeGenericMethod(typeof(TValue), typeof(TValue));

list = Expression.Lambda<Func<T, IEnumerable<TValue>>>(
rewrittenList = Expression.Lambda<Func<T, IEnumerable<TValue>>>(
Expression.Call(
selectMethod,
rewrittenList.Body,
Expand All @@ -254,7 +257,7 @@ public static IQueryable<KeyValueResult<TKey, IEnumerable<TValue>>> Select<T, TK
Expression.Bind(
typeof(KeyValueResult<TKey, IEnumerable<TValue>>).GetProperty(
nameof(KeyValueResult<TKey, IEnumerable<TValue>>.Value))!,
list.Body)),
rewrittenList.Body)),
parameter);

// lastly we apply the selector expression to the queryable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static Expression<Func<TValue, TValue>> AsSelector<TValue>(

if ((flags & FieldFlags.Connection) == FieldFlags.Connection)
{
var builder = new DefaultSelectorBuilder<TValue>();
var builder = new DefaultSelectorBuilder();
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
var count = GetConnectionSelections(selection, buffer);
for (var i = 0; i < count; i++)
Expand All @@ -63,7 +63,7 @@ public static Expression<Func<TValue, TValue>> AsSelector<TValue>(

if ((flags & FieldFlags.CollectionSegment) == FieldFlags.CollectionSegment)
{
var builder = new DefaultSelectorBuilder<TValue>();
var builder = new DefaultSelectorBuilder();
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
var count = GetCollectionSelections(selection, buffer);
for (var i = 0; i < count; i++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,118 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
/// <param name="dataLoader">
/// The data loader.
/// </param>
/// <param name="selection">
/// The selection that shall be applied to the data loader.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a new data loader that applies the selection.
/// </returns>
public static IDataLoader<TKey, TValue[]> Select<TKey, TValue>(
this IDataLoader<TKey, TValue[]> dataLoader,
ISelection selection)
where TKey : notnull
where TValue : notnull
{
var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
/// <param name="dataLoader">
/// The data loader.
/// </param>
/// <param name="selection">
/// The selection that shall be applied to the data loader.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a new data loader that applies the selection.
/// </returns>
public static IDataLoader<TKey, ICollection<TValue>> Select<TKey, TValue>(
this IDataLoader<TKey, ICollection<TValue>> dataLoader,
ISelection selection)
where TKey : notnull
where TValue : notnull
{
var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
/// <param name="dataLoader">
/// The data loader.
/// </param>
/// <param name="selection">
/// The selection that shall be applied to the data loader.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a new data loader that applies the selection.
/// </returns>
public static IDataLoader<TKey, IEnumerable<TValue>> Select<TKey, TValue>(
this IDataLoader<TKey, IEnumerable<TValue>> dataLoader,
ISelection selection)
where TKey : notnull
where TValue : notnull
{
var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
/// <param name="dataLoader">
/// The data loader.
/// </param>
/// <param name="selection">
/// The selection that shall be applied to the data loader.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a new data loader that applies the selection.
/// </returns>
public static IDataLoader<TKey, List<TValue>> Select<TKey, TValue>(
this IDataLoader<TKey, List<TValue>> dataLoader,
ISelection selection)
where TKey : notnull
where TValue : notnull
{
var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ public void WriteDataLoaderLoadMethod(
parameter.StateKey);
_writer.IncreaseIndent();
_writer.WriteIndentedLine(
"?? new global::GreenDonut.Projections.DefaultSelectorBuilder<{0}>();",
value.ToFullyQualified());
"?? new global::GreenDonut.Projections.DefaultSelectorBuilder();");
_writer.DecreaseIndent();
}
else if (parameter.Kind is DataLoaderParameterKind.PagingArguments)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,38 @@ public async Task Project_Key_To_Collection_Expression()
.MatchMarkdownSnapshot();
}

[Fact]
public async Task Project_Key_To_Collection_Expression_Integration()
{
// Arrange
var queries = new List<string>();
var connectionString = CreateConnectionString();
await CatalogContext.SeedAsync(connectionString);

// Act
var result = await new ServiceCollection()
.AddScoped(_ => queries)
.AddTransient(_ => new CatalogContext(connectionString))
.AddGraphQLServer()
.AddQueryType<BrandsQuery>()
.AddTypeExtension(typeof(BrandListExtensions))
.ExecuteRequestAsync(
"""
{
brandById(id: 1) {
products {
name
}
}
}
""");
// Assert
Snapshot.Create()
.AddSql(queries)
.Add(result, "Result")
.MatchMarkdownSnapshot();
}

public class Query
{
public async Task<Brand?> GetBrandByIdAsync(
Expand Down Expand Up @@ -792,6 +824,19 @@ public class NodeQuery
}
}

public class BrandsQuery
{
public async Task<IEnumerable<Brand>> GetBrandByIdAsync(
int id,
ISelection selection,
CatalogContext context,
CancellationToken cancellationToken)
=> await context.Brands
.Select(selection.AsSelector<Brand>())
.Take(2)
.ToListAsync(cancellationToken);
}

[ExtendObjectType<Brand>]
public class BrandExtensions
{
Expand All @@ -810,6 +855,18 @@ public string GetDetails(
=> "Brand Name:" + brand.Name;
}

[ExtendObjectType<Brand>]
public class BrandListExtensions
{
[BindMember(nameof(Brand.Products))]
public async Task<IEnumerable<Product>?> GetProductsAsync(
[Parent] Brand brand,
ProductByBrandIdDataLoader2 productByBrandId,
ISelection selection,
CancellationToken cancellationToken)
=> await productByBrandId.Select(selection).LoadAsync(brand.Id, cancellationToken);
}

public class BrandWithRequirementType : ObjectType<Brand>
{
protected override void Configure(IObjectTypeDescriptor<Brand> descriptor)
Expand Down Expand Up @@ -934,7 +991,7 @@ protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsyn
CancellationToken cancellationToken)
{
var catalogContext = services.GetRequiredService<CatalogContext>();
var selector = new DefaultSelectorBuilder<Product>();
var selector = new DefaultSelectorBuilder();
selector.Add<Product>(t => new Product { Name = t.Name });

var query = catalogContext.Brands
Expand All @@ -951,6 +1008,35 @@ protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsyn
return x;
}
}

public class ProductByBrandIdDataLoader2(
IServiceProvider services,
List<string> queries,
IBatchScheduler batchScheduler,
DataLoaderOptions options)
: StatefulBatchDataLoader<int, Product[]>(batchScheduler, options)
{
protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsync(
IReadOnlyList<int> keys,
DataLoaderFetchContext<Product[]> context,
CancellationToken cancellationToken)
{
var catalogContext = services.GetRequiredService<CatalogContext>();

var query = catalogContext.Brands
.Where(t => keys.Contains(t.Id))
.Select(t => t.Id, t => t.Products, context.GetSelector());

lock (queries)
{
queries.Add(query.ToQueryString());
}

var x = await query.ToDictionaryAsync(t => t.Key, t => t.Value.ToArray(), cancellationToken);

return x;
}
}
}

file static class Extensions
Expand Down
Loading

0 comments on commit 6f8272b

Please sign in to comment.