Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added more overloads for list selects. #7578

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/GreenDonut/src/Core/DataLoaderFetchContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,6 @@ public ISelectorBuilder GetSelector()

// if no selector was found we will just return
// a new default selector builder.
return new DefaultSelectorBuilder<TValue>();
return new DefaultSelectorBuilder();
}
}
15 changes: 1 addition & 14 deletions src/GreenDonut/src/Core/Projections/DefaultSelectorBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,14 @@ namespace GreenDonut.Projections;
/// <summary>
/// A default implementation of the <see cref="ISelectorBuilder"/>.
/// </summary>
/// <typeparam name="TValue"></typeparam>
[Experimental(Experiments.Projections)]
public sealed class DefaultSelectorBuilder<TValue> : ISelectorBuilder
public sealed class DefaultSelectorBuilder : ISelectorBuilder
{
private List<LambdaExpression>? _selectors;

/// <inheritdoc />
public void Add<T>(Expression<Func<T, T>> selector)
{
if (typeof(T) != typeof(TValue))
{
throw new ArgumentException(
"The projection type must match the DataLoader value type.",
nameof(selector));
}

_selectors ??= new List<LambdaExpression>();
if (!_selectors.Contains(selector))
{
Expand All @@ -37,11 +29,6 @@ public void Add<T>(Expression<Func<T, T>> selector)
return null;
}

if (typeof(T) != typeof(TValue))
{
return null;
}

if (_selectors.Count == 1)
{
return (Expression<Func<T, T>>)_selectors[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,18 @@ public static class SelectionDataLoaderExtensions
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <typeparam name="TElement">
/// The element type.
/// </typeparam>
/// <returns>
/// Returns a branched DataLoader with the selector applied.
/// </returns>
/// <exception cref="ArgumentNullException">
/// Throws if <paramref name="dataLoader"/> is <c>null</c>.
/// </exception>
public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
public static IDataLoader<TKey, TValue> Select<TKey, TValue, TElement>(
this IDataLoader<TKey, TValue> dataLoader,
Expression<Func<TValue, TValue>>? selector)
Expression<Func<TElement, TElement>>? selector)
where TKey : notnull
{
if (dataLoader is null)
Expand All @@ -55,7 +58,7 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(

if (dataLoader is ISelectionDataLoader<TKey, TValue>)
{
var context = (DefaultSelectorBuilder<TValue>)dataLoader.ContextData[typeof(ISelectorBuilder).FullName!]!;
var context = (DefaultSelectorBuilder)dataLoader.ContextData[typeof(ISelectorBuilder).FullName!]!;
context.Add(selector);
return dataLoader;
}
Expand All @@ -66,12 +69,12 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
static IDataLoader CreateBranch(
string key,
IDataLoader<TKey, TValue> dataLoader,
Expression<Func<TValue, TValue>> selector)
Expression<Func<TElement, TElement>> selector)
{
var branch = new SelectionDataLoader<TKey, TValue>(
(DataLoaderBase<TKey, TValue>)dataLoader,
key);
var context = new DefaultSelectorBuilder<TValue>();
var context = new DefaultSelectorBuilder();
branch.ContextData = branch.ContextData.SetItem(typeof(ISelectorBuilder).FullName!, context);
context.Add(selector);
return branch;
Expand Down Expand Up @@ -134,7 +137,7 @@ public static IDataLoader<TKey, TValue> Include<TKey, TValue>(

var context = dataLoader.GetOrSetState(
typeof(ISelectorBuilder).FullName!,
_ => new DefaultSelectorBuilder<TValue>());
_ => new DefaultSelectorBuilder());
context.Add(Rewrite(includeSelector));
return dataLoader;
}
Expand Down Expand Up @@ -235,7 +238,7 @@ public static IQueryable<KeyValueResult<TKey, IEnumerable<TValue>>> Select<T, TK
{
var selectMethod = _selectMethod.MakeGenericMethod(typeof(TValue), typeof(TValue));

list = Expression.Lambda<Func<T, IEnumerable<TValue>>>(
rewrittenList = Expression.Lambda<Func<T, IEnumerable<TValue>>>(
Expression.Call(
selectMethod,
rewrittenList.Body,
Expand All @@ -254,7 +257,7 @@ public static IQueryable<KeyValueResult<TKey, IEnumerable<TValue>>> Select<T, TK
Expression.Bind(
typeof(KeyValueResult<TKey, IEnumerable<TValue>>).GetProperty(
nameof(KeyValueResult<TKey, IEnumerable<TValue>>.Value))!,
list.Body)),
rewrittenList.Body)),
parameter);

// lastly we apply the selector expression to the queryable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static Expression<Func<TValue, TValue>> AsSelector<TValue>(

if ((flags & FieldFlags.Connection) == FieldFlags.Connection)
{
var builder = new DefaultSelectorBuilder<TValue>();
var builder = new DefaultSelectorBuilder();
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
var count = GetConnectionSelections(selection, buffer);
for (var i = 0; i < count; i++)
Expand All @@ -63,7 +63,7 @@ public static Expression<Func<TValue, TValue>> AsSelector<TValue>(

if ((flags & FieldFlags.CollectionSegment) == FieldFlags.CollectionSegment)
{
var builder = new DefaultSelectorBuilder<TValue>();
var builder = new DefaultSelectorBuilder();
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
var count = GetCollectionSelections(selection, buffer);
for (var i = 0; i < count; i++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,118 @@ public static IDataLoader<TKey, TValue> Select<TKey, TValue>(
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
/// <param name="dataLoader">
/// The data loader.
/// </param>
/// <param name="selection">
/// The selection that shall be applied to the data loader.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a new data loader that applies the selection.
/// </returns>
public static IDataLoader<TKey, TValue[]> Select<TKey, TValue>(
this IDataLoader<TKey, TValue[]> dataLoader,
ISelection selection)
where TKey : notnull
where TValue : notnull
{
var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
/// <param name="dataLoader">
/// The data loader.
/// </param>
/// <param name="selection">
/// The selection that shall be applied to the data loader.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a new data loader that applies the selection.
/// </returns>
public static IDataLoader<TKey, ICollection<TValue>> Select<TKey, TValue>(
this IDataLoader<TKey, ICollection<TValue>> dataLoader,
ISelection selection)
where TKey : notnull
where TValue : notnull
{
var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
/// <param name="dataLoader">
/// The data loader.
/// </param>
/// <param name="selection">
/// The selection that shall be applied to the data loader.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a new data loader that applies the selection.
/// </returns>
public static IDataLoader<TKey, IEnumerable<TValue>> Select<TKey, TValue>(
this IDataLoader<TKey, IEnumerable<TValue>> dataLoader,
ISelection selection)
where TKey : notnull
where TValue : notnull
{
var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
/// <param name="dataLoader">
/// The data loader.
/// </param>
/// <param name="selection">
/// The selection that shall be applied to the data loader.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a new data loader that applies the selection.
/// </returns>
public static IDataLoader<TKey, List<TValue>> Select<TKey, TValue>(
this IDataLoader<TKey, List<TValue>> dataLoader,
ISelection selection)
where TKey : notnull
where TValue : notnull
{
var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}

/// <summary>
/// Selects the fields that where selected in the GraphQL selection tree.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ public void WriteDataLoaderLoadMethod(
parameter.StateKey);
_writer.IncreaseIndent();
_writer.WriteIndentedLine(
"?? new global::GreenDonut.Projections.DefaultSelectorBuilder<{0}>();",
value.ToFullyQualified());
"?? new global::GreenDonut.Projections.DefaultSelectorBuilder();");
_writer.DecreaseIndent();
}
else if (parameter.Kind is DataLoaderParameterKind.PagingArguments)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,38 @@ public async Task Project_Key_To_Collection_Expression()
.MatchMarkdownSnapshot();
}

[Fact]
public async Task Project_Key_To_Collection_Expression_Integration()
{
// Arrange
var queries = new List<string>();
var connectionString = CreateConnectionString();
await CatalogContext.SeedAsync(connectionString);

// Act
var result = await new ServiceCollection()
.AddScoped(_ => queries)
.AddTransient(_ => new CatalogContext(connectionString))
.AddGraphQLServer()
.AddQueryType<BrandsQuery>()
.AddTypeExtension(typeof(BrandListExtensions))
.ExecuteRequestAsync(
"""
{
brandById(id: 1) {
products {
name
}
}
}
""");
// Assert
Snapshot.Create()
.AddSql(queries)
.Add(result, "Result")
.MatchMarkdownSnapshot();
}

public class Query
{
public async Task<Brand?> GetBrandByIdAsync(
Expand Down Expand Up @@ -792,6 +824,19 @@ public class NodeQuery
}
}

public class BrandsQuery
{
public async Task<IEnumerable<Brand>> GetBrandByIdAsync(
int id,
ISelection selection,
CatalogContext context,
CancellationToken cancellationToken)
=> await context.Brands
.Select(selection.AsSelector<Brand>())
.Take(2)
.ToListAsync(cancellationToken);
}

[ExtendObjectType<Brand>]
public class BrandExtensions
{
Expand All @@ -810,6 +855,18 @@ public string GetDetails(
=> "Brand Name:" + brand.Name;
}

[ExtendObjectType<Brand>]
public class BrandListExtensions
{
[BindMember(nameof(Brand.Products))]
public async Task<IEnumerable<Product>?> GetProductsAsync(
[Parent] Brand brand,
ProductByBrandIdDataLoader2 productByBrandId,
ISelection selection,
CancellationToken cancellationToken)
=> await productByBrandId.Select(selection).LoadAsync(brand.Id, cancellationToken);
}

public class BrandWithRequirementType : ObjectType<Brand>
{
protected override void Configure(IObjectTypeDescriptor<Brand> descriptor)
Expand Down Expand Up @@ -934,7 +991,7 @@ protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsyn
CancellationToken cancellationToken)
{
var catalogContext = services.GetRequiredService<CatalogContext>();
var selector = new DefaultSelectorBuilder<Product>();
var selector = new DefaultSelectorBuilder();
selector.Add<Product>(t => new Product { Name = t.Name });

var query = catalogContext.Brands
Expand All @@ -951,6 +1008,35 @@ protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsyn
return x;
}
}

public class ProductByBrandIdDataLoader2(
IServiceProvider services,
List<string> queries,
IBatchScheduler batchScheduler,
DataLoaderOptions options)
: StatefulBatchDataLoader<int, Product[]>(batchScheduler, options)
{
protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsync(
IReadOnlyList<int> keys,
DataLoaderFetchContext<Product[]> context,
CancellationToken cancellationToken)
{
var catalogContext = services.GetRequiredService<CatalogContext>();

var query = catalogContext.Brands
.Where(t => keys.Contains(t.Id))
.Select(t => t.Id, t => t.Products, context.GetSelector());

lock (queries)
{
queries.Add(query.ToQueryString());
}

var x = await query.ToDictionaryAsync(t => t.Key, t => t.Value.ToArray(), cancellationToken);

return x;
}
}
}

file static class Extensions
Expand Down
Loading
Loading