Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ private static readonly MethodInfo CollectionAccessorAddMethodInfo
= typeof(IClrCollectionAccessor).GetTypeInfo()
.GetDeclaredMethod(nameof(IClrCollectionAccessor.Add));

private static readonly MethodInfo CollectionAccessorGetOrCreateMethodInfo
= typeof(IClrCollectionAccessor).GetTypeInfo()
.GetDeclaredMethod(nameof(IClrCollectionAccessor.GetOrCreate));

private readonly IDictionary<ParameterExpression, Expression> _materializationContextBindings
= new Dictionary<ParameterExpression, Expression>();

Expand Down Expand Up @@ -402,7 +398,6 @@ private void AddInclude(
var inverseNavigation = navigation.Inverse;
var fixup = GenerateFixup(
includingClrType, relatedEntityClrType, navigation, inverseNavigation);
var initialize = GenerateInitialize(includingClrType, navigation);

var navigationExpression = Visit(includeExpression.NavigationExpression);

Expand All @@ -421,7 +416,6 @@ private void AddInclude(
Constant(navigation),
Constant(inverseNavigation, typeof(INavigation)),
Constant(fixup),
Constant(initialize, typeof(Action<>).MakeGenericType(includingClrType)),
#pragma warning disable EF1001 // Internal EF Core API usage.
Constant(includeExpression.SetLoaded))));
#pragma warning restore EF1001 // Internal EF Core API usage.
Expand All @@ -441,8 +435,7 @@ private static void IncludeReference<TIncludingEntity, TIncludedEntity>(
INavigation navigation,
INavigation inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
Action<TIncludingEntity> _,
bool __)
bool _)
{
if (entity == null
|| !navigation.DeclaringEntityType.IsAssignableFrom(entityType))
Expand Down Expand Up @@ -486,7 +479,6 @@ private static void IncludeCollection<TIncludingEntity, TIncludedEntity>(
INavigation navigation,
INavigation inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
Action<TIncludingEntity> initialize,
bool setLoaded)
{
if (entity == null
Expand All @@ -495,6 +487,8 @@ private static void IncludeCollection<TIncludingEntity, TIncludedEntity>(
return;
}

navigation.GetCollectionAccessor()!.GetOrCreate(entity, forMaterialization: true);

if (entry == null)
{
var includingEntity = (TIncludingEntity)entity;
Expand All @@ -508,10 +502,6 @@ private static void IncludeCollection<TIncludingEntity, TIncludedEntity>(
inverseNavigation?.SetIsLoadedWhenNoTracking(relatedEntity);
}
}
else
{
initialize(includingEntity);
}
}
else
{
Expand All @@ -524,15 +514,12 @@ private static void IncludeCollection<TIncludingEntity, TIncludedEntity>(

if (relatedEntities != null)
{
// Enumerator contains logic for tracking the entities, so we need to make sure to enumerate it
using var enumerator = relatedEntities.GetEnumerator();
while (enumerator.MoveNext())
{
}
}
else
{
initialize((TIncludingEntity)entity);
}
}
}

Expand Down Expand Up @@ -563,27 +550,6 @@ private static Delegate GenerateFixup(
.Compile();
}

private static Delegate GenerateInitialize(
Type entityType,
INavigation navigation)
{
if (!navigation.IsCollection)
{
return null;
}

var entityParameter = Parameter(entityType);

var getOrCreateExpression = Call(
Constant(navigation.GetCollectionAccessor()),
CollectionAccessorGetOrCreateMethodInfo,
entityParameter,
Constant(true));

return Lambda(Block(typeof(void), getOrCreateExpression), entityParameter)
.Compile();
}

private static Expression AssignReferenceNavigation(
ParameterExpression entity,
ParameterExpression relatedEntity,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,82 +31,4 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
.ToContainer("RootEntities")
.HasNoDiscriminator();
}

// We need to override the following asserters because of #36577:
// the Cosmos provider incorrectly returns null for empty collections in some cases
protected override void AssertRootEntity(RootEntity e, RootEntity a)
{
Assert.Equal(e.Id, a.Id);
Assert.Equal(e.Name, a.Name);

NullSafeAssert<AssociateType>(e.RequiredAssociate, a.RequiredAssociate, AssertAssociate);
NullSafeAssert<AssociateType>(e.OptionalAssociate, a.OptionalAssociate, AssertAssociate);

if (e.AssociateCollection is not null && a.AssociateCollection is not null)
{
Assert.Equal(e.AssociateCollection.Count, a.AssociateCollection.Count);

var (orderedExpected, orderedActual) = (e.AssociateCollection, a.AssociateCollection);

for (var i = 0; i < e.AssociateCollection.Count; i++)
{
AssertAssociate(orderedExpected[i], orderedActual[i]);
}
}
else
{
// #36577: the Cosmos provider incorrectly returns null for empty collections in some cases
if (e.AssociateCollection is [] && a.AssociateCollection is null)
{
return;
}

Assert.Equal(e.AssociateCollection, a.AssociateCollection);
}
}

protected override void AssertAssociate(AssociateType e, AssociateType a)
{
Assert.Equal(e.Id, a.Id);
Assert.Equal(e.Name, a.Name);

Assert.Equal(e.Int, a.Int);
Assert.Equal(e.String, a.String);

NullSafeAssert<NestedAssociateType>(e.RequiredNestedAssociate, a.RequiredNestedAssociate, AssertNestedAssociate);
NullSafeAssert<NestedAssociateType>(e.OptionalNestedAssociate, a.OptionalNestedAssociate, AssertNestedAssociate);

if (e.NestedCollection is not null && a.NestedCollection != null)
{
Assert.Equal(e.NestedCollection.Count, a.NestedCollection.Count);

var (orderedExpected, orderedActual) = (e.NestedCollection, a.NestedCollection);

for (var i = 0; i < e.NestedCollection.Count; i++)
{
AssertNestedAssociate(orderedExpected[i], orderedActual[i]);
}
}
else
{
// #36577: the Cosmos provider incorrectly returns null for empty collections in some cases
if (e.NestedCollection is [] && a.NestedCollection is null)
{
return;
}

Assert.Equal(e.NestedCollection, a.NestedCollection);
}
}

private static void NullSafeAssert<T>(object? e, object? a, Action<T, T> assertAction)
{
if (e is T ee && a is T aa)
{
assertAction(ee, aa);
return;
}

Assert.Equal(e, a);
}
}
Loading