Skip to content

Commit

Permalink
When sorting update commands add Deleted + Inserted number of edges i…
Browse files Browse the repository at this point in the history
…nstead of Deleted * Inserted

Don't allocate a collection for edges if there's only one edge between given two vertices
Reuse the Multigraph instance
Change the benchmark to reuse the context instance to simulate pooling behavior
  • Loading branch information
AndriySvyryd committed Sep 20, 2021
1 parent 901bfb4 commit c6a9f1c
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ public virtual void InitializeFixture()
{
_fixture = CreateFixture();
_fixture.Initialize(0, 1000, 0, 0);
_context = _fixture.CreateContext(disableBatching: Batching);
}

public virtual void InitializeContext()
{
_context = _fixture.CreateContext(disableBatching: Batching);
_transaction = _context.Database.BeginTransaction();
}

Expand All @@ -53,7 +53,7 @@ public virtual void CleanupContext()
}

_transaction.Dispose();
_context.Dispose();
_context.ChangeTracker.Clear();
}

[Benchmark]
Expand Down
138 changes: 99 additions & 39 deletions src/EFCore.Relational/Update/Internal/CommandBatchPreparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class CommandBatchPreparer : ICommandBatchPreparer
{
private readonly int _minBatchSize;
private readonly bool _sensitiveLoggingEnabled;
private readonly Multigraph<IReadOnlyModificationCommand, IAnnotatable> _modificationCommandGraph = new();

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -278,18 +279,18 @@ private void AddUnchangedSharingEntries(
/// </summary>
protected virtual IReadOnlyList<List<IReadOnlyModificationCommand>> TopologicalSort(IEnumerable<IReadOnlyModificationCommand> commands)
{
var modificationCommandGraph = new Multigraph<IReadOnlyModificationCommand, IAnnotatable>();
modificationCommandGraph.AddVertices(commands);
_modificationCommandGraph.Clear();
_modificationCommandGraph.AddVertices(commands);

// The predecessors map allows to populate the graph in linear time
var predecessorsMap = CreateKeyValuePredecessorMap(modificationCommandGraph);
AddForeignKeyEdges(modificationCommandGraph, predecessorsMap);
var predecessorsMap = CreateKeyValuePredecessorMap(_modificationCommandGraph);
AddForeignKeyEdges(_modificationCommandGraph, predecessorsMap);

AddUniqueValueEdges(modificationCommandGraph);
AddUniqueValueEdges(_modificationCommandGraph);

AddSameTableEdges(modificationCommandGraph);
AddSameTableEdges(_modificationCommandGraph);

return modificationCommandGraph.BatchingTopologicalSort(static (_, _, edges) => edges.All(e => e is IEntityType), FormatCycle);
return _modificationCommandGraph.BatchingTopologicalSort(static (_, _, edges) => edges.All(e => e is IEntityType), FormatCycle);
}

private string FormatCycle(IReadOnlyList<Tuple<IReadOnlyModificationCommand, IReadOnlyModificationCommand, IEnumerable<IAnnotatable>>> data)
Expand Down Expand Up @@ -488,12 +489,8 @@ private Dictionary<IKeyValueIndex, List<IReadOnlyModificationCommand>> CreateKey
var entry = command.Entries[i];
foreach (var foreignKey in entry.EntityType.GetReferencingForeignKeys())
{
var constraints = foreignKey.GetMappedConstraints()
.Where(c => c.PrincipalTable.Name == command.TableName && c.PrincipalTable.Schema == command.Schema);

if (!constraints.Any()
|| (entry.EntityState == EntityState.Modified
&& !foreignKey.PrincipalKey.Properties.Any(p => entry.IsModified(p))))
if (!IsMapped(foreignKey, command, principal: true)
|| !IsModified(foreignKey.PrincipalKey.Properties, entry))
{
continue;
}
Expand Down Expand Up @@ -523,12 +520,8 @@ private Dictionary<IKeyValueIndex, List<IReadOnlyModificationCommand>> CreateKey
{
foreach (var foreignKey in entry.EntityType.GetForeignKeys())
{
var constraints = foreignKey.GetMappedConstraints()
.Where(c => c.Table.Name == command.TableName && c.Table.Schema == command.Schema);

if (!constraints.Any()
|| (entry.EntityState == EntityState.Modified
&& !foreignKey.Properties.Any(p => entry.IsModified(p))))
if (!IsMapped(foreignKey, command, principal: false)
|| !IsModified(foreignKey.Properties, entry))
{
continue;
}
Expand Down Expand Up @@ -571,10 +564,8 @@ private void AddForeignKeyEdges(
var entry = command.Entries[entryIndex];
foreach (var foreignKey in entry.EntityType.GetForeignKeys())
{
if (!foreignKey.GetMappedConstraints()
.Any(c => c.Table.Name == command.TableName && c.Table.Schema == command.Schema)
|| (entry.EntityState == EntityState.Modified
&& !foreignKey.Properties.Any(p => entry.IsModified(p))))
if (!IsMapped(foreignKey, command, principal: false)
|| !IsModified(foreignKey.Properties, entry))
{
continue;
}
Expand All @@ -600,9 +591,7 @@ private void AddForeignKeyEdges(
var entry = command.Entries[entryIndex];
foreach (var foreignKey in entry.EntityType.GetReferencingForeignKeys())
{
var constraints = foreignKey.GetMappedConstraints()
.Where(c => c.PrincipalTable.Name == command.TableName && c.PrincipalTable.Schema == command.Schema);
if (!constraints.Any())
if (!IsMapped(foreignKey, command, principal: true))
{
continue;
}
Expand All @@ -623,6 +612,49 @@ private void AddForeignKeyEdges(
}
}

private static bool IsMapped(IForeignKey foreignKey, IReadOnlyModificationCommand command, bool principal)
{
foreach (var constraint in foreignKey.GetMappedConstraints())
{
if (principal)
{
if (constraint.PrincipalTable.Name == command.TableName
&& constraint.PrincipalTable.Schema == command.Schema)
{
return true;
}
}
else
{
if (constraint.Table.Name == command.TableName
&& constraint.Table.Schema == command.Schema)
{
return true;
}
}
}

return false;
}

private static bool IsModified(IReadOnlyList<IProperty> properties, IUpdateEntry entry)
{
if (entry.EntityState != EntityState.Modified)
{
return true;
}

foreach (var property in properties)
{
if (entry.IsModified(property))
{
return true;
}
}

return false;
}

private static void AddMatchingPredecessorEdge<T>(
Dictionary<T, List<IReadOnlyModificationCommand>> predecessorsMap,
T keyValue,
Expand Down Expand Up @@ -658,10 +690,11 @@ private void AddUniqueValueEdges(Multigraph<IReadOnlyModificationCommand, IAnnot
for (var entryIndex = 0; entryIndex < command.Entries.Count; entryIndex++)
{
var entry = command.Entries[entryIndex];
foreach (var index in entry.EntityType.GetIndexes().Where(i => i.IsUnique && i.GetMappedTableIndexes().Any()))
foreach (var index in entry.EntityType.GetIndexes())
{
if (entry.EntityState == EntityState.Modified
&& !index.Properties.Any(p => entry.IsModified(p)))
if (!index.IsUnique
|| !index.GetMappedTableIndexes().Any()
|| !IsModified(index.Properties, entry))
{
continue;
}
Expand All @@ -688,8 +721,13 @@ private void AddUniqueValueEdges(Multigraph<IReadOnlyModificationCommand, IAnnot
continue;
}

foreach (var key in entry.EntityType.GetKeys().Where(k => k.GetMappedConstraints().Any()))
foreach (var key in entry.EntityType.GetKeys())
{
if (!key.GetMappedConstraints().Any())
{
continue;
}

var principalKeyValue = Dependencies.KeyValueIndexFactorySource
.GetKeyValueIndexFactory(key)
.CreatePrincipalKeyValue(entry, null);
Expand Down Expand Up @@ -719,10 +757,11 @@ private void AddUniqueValueEdges(Multigraph<IReadOnlyModificationCommand, IAnnot

foreach (var entry in command.Entries)
{
foreach (var index in entry.EntityType.GetIndexes().Where(i => i.IsUnique && i.GetMappedTableIndexes().Any()))
foreach (var index in entry.EntityType.GetIndexes())
{
if (entry.EntityState == EntityState.Modified
&& !index.Properties.Any(p => entry.IsModified(p)))
if (!index.IsUnique
|| !index.GetMappedTableIndexes().Any()
|| !IsModified(index.Properties, entry))
{
continue;
}
Expand Down Expand Up @@ -751,8 +790,13 @@ private void AddUniqueValueEdges(Multigraph<IReadOnlyModificationCommand, IAnnot

foreach (var entry in command.Entries)
{
foreach (var key in entry.EntityType.GetKeys().Where(k => k.GetMappedConstraints().Any()))
foreach (var key in entry.EntityType.GetKeys())
{
if (!key.GetMappedConstraints().Any())
{
continue;
}

var principalKeyValue = Dependencies.KeyValueIndexFactorySource
.GetKeyValueIndexFactory(key)
.CreatePrincipalKeyValue(entry, null);
Expand All @@ -770,26 +814,42 @@ private void AddUniqueValueEdges(Multigraph<IReadOnlyModificationCommand, IAnnot

private static void AddSameTableEdges(Multigraph<IReadOnlyModificationCommand, IAnnotatable> modificationCommandGraph)
{
var deletedDictionary = new Dictionary<(string, string?), List<IReadOnlyModificationCommand>>();
var deletedDictionary = new Dictionary<(string, string?), (List<IReadOnlyModificationCommand> List, bool EdgesAdded)>();

foreach (var command in modificationCommandGraph.Vertices)
{
if (command.EntityState == EntityState.Deleted)
{
deletedDictionary.GetOrAddNew((command.TableName, command.Schema)).Add(command);
var table = (command.TableName, command.Schema);
if (!deletedDictionary.TryGetValue(table, out var deletedCommands))
{
deletedCommands = (new List<IReadOnlyModificationCommand>(), false);
deletedDictionary.Add(table, deletedCommands);
}
deletedCommands.List.Add(command);
}
}

foreach (var command in modificationCommandGraph.Vertices)
{
if (command.EntityState == EntityState.Added)
{
if (deletedDictionary.TryGetValue((command.TableName, command.Schema), out var deletedList))
var table = (command.TableName, command.Schema);
if (deletedDictionary.TryGetValue(table, out var deletedCommands))
{
foreach (var deleted in deletedList)
var lastDelete = deletedCommands.List[^1];
if (!deletedCommands.EdgesAdded)
{
modificationCommandGraph.AddEdge(deleted, command, command.Entries[0].EntityType);
for (var i = 0; i < deletedCommands.List.Count - 1; i++)
{
var deleted = deletedCommands.List[i];
modificationCommandGraph.AddEdge(deleted, lastDelete, deleted.Entries[0].EntityType);
}

deletedDictionary[table] = (deletedCommands.List, true);
}

modificationCommandGraph.AddEdge(lastDelete, command, command.Entries[0].EntityType);
}
}
}
Expand Down
50 changes: 30 additions & 20 deletions src/Shared/Multigraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,16 @@ internal class Multigraph<TVertex, TEdge> : Graph<TVertex>
where TVertex : notnull
{
private readonly HashSet<TVertex> _vertices = new();
private readonly Dictionary<TVertex, Dictionary<TVertex, List<TEdge>>> _successorMap = new();
private readonly Dictionary<TVertex, Dictionary<TVertex, object?>> _successorMap = new();
private readonly Dictionary<TVertex, HashSet<TVertex>> _predecessorMap = new();

public IEnumerable<TEdge> Edges
=> _successorMap.Values.SelectMany(s => s.Values).SelectMany(e => e).Distinct();

public IEnumerable<TEdge> GetEdges(TVertex from, TVertex to)
{
if (_successorMap.TryGetValue(from, out var successorSet))
{
if (successorSet.TryGetValue(to, out var edgeList))
if (successorSet.TryGetValue(to, out var edges))
{
return edgeList;
return edges is IEnumerable<TEdge> edgeList ? edgeList : (new TEdge[] { (TEdge)edges! });
}
}

Expand Down Expand Up @@ -55,17 +52,23 @@ public void AddEdge(TVertex from, TVertex to, TEdge edge)

if (!_successorMap.TryGetValue(from, out var successorEdges))
{
successorEdges = new Dictionary<TVertex, List<TEdge>>();
successorEdges = new Dictionary<TVertex, object?>();
_successorMap.Add(from, successorEdges);
}

if (!successorEdges.TryGetValue(to, out var edgeList))
if (successorEdges.TryGetValue(to, out var edges))
{
edgeList = new List<TEdge>();
successorEdges.Add(to, edgeList);
if (edges is not List<TEdge> edgeList)
{
edgeList = new List<TEdge> { (TEdge)edges! };
successorEdges[to] = edgeList;
}
edgeList.Add(edge);
}
else
{
successorEdges.Add(to, edge);
}

edgeList.Add(edge);

if (!_predecessorMap.TryGetValue(to, out var predecessors))
{
Expand All @@ -76,7 +79,7 @@ public void AddEdge(TVertex from, TVertex to, TEdge edge)
predecessors.Add(from);
}

public void AddEdges(TVertex from, TVertex to, IEnumerable<TEdge> edges)
public void AddEdges(TVertex from, TVertex to, IEnumerable<TEdge> newEdges)
{
#if DEBUG
if (!_vertices.Contains(from))
Expand All @@ -92,18 +95,25 @@ public void AddEdges(TVertex from, TVertex to, IEnumerable<TEdge> edges)

if (!_successorMap.TryGetValue(from, out var successorEdges))
{
successorEdges = new Dictionary<TVertex, List<TEdge>>();
successorEdges = new Dictionary<TVertex, object?>();
_successorMap.Add(from, successorEdges);
}

if (!successorEdges.TryGetValue(to, out var edgeList))
if (successorEdges.TryGetValue(to, out var edges))
{
edgeList = new List<TEdge>();
if (edges is not List<TEdge> edgeList)
{
edgeList = new List<TEdge> { (TEdge)edges! };
successorEdges[to] = edgeList;
}
edgeList.AddRange(newEdges);
}
else
{
var edgeList = newEdges.ToList();
successorEdges.Add(to, edgeList);
}

edgeList.AddRange(edges);

if (!_predecessorMap.TryGetValue(to, out var predecessors))
{
predecessors = new HashSet<TVertex>();
Expand Down Expand Up @@ -193,7 +203,7 @@ public IReadOnlyList<TVertex> TopologicalSort(
.First(neighbor => predecessorCounts.TryGetValue(neighbor, out var neighborPredecessors)
&& neighborPredecessors > 0);

if (tryBreakEdge(incomingNeighbor, candidateVertex, _successorMap[incomingNeighbor][candidateVertex]))
if (tryBreakEdge(incomingNeighbor, candidateVertex, GetEdges(incomingNeighbor, candidateVertex)))
{
_successorMap[incomingNeighbor].Remove(candidateVertex);
_predecessorMap[candidateVertex].Remove(incomingNeighbor);
Expand Down Expand Up @@ -367,7 +377,7 @@ public IReadOnlyList<List<TVertex>> BatchingTopologicalSort(
.First(neighbor => predecessorCounts.TryGetValue(neighbor, out var neighborPredecessors)
&& neighborPredecessors > 0);

if (tryBreakEdge(incomingNeighbor, candidateVertex, _successorMap[incomingNeighbor][candidateVertex]))
if (tryBreakEdge(incomingNeighbor, candidateVertex, GetEdges(incomingNeighbor, candidateVertex)))
{
_successorMap[incomingNeighbor].Remove(candidateVertex);
_predecessorMap[candidateVertex].Remove(incomingNeighbor);
Expand Down
3 changes: 0 additions & 3 deletions test/EFCore.Tests/Utilities/MultigraphTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@ public void AddEdge_adds_an_edge()
graph.AddEdge(vertexOne, vertexTwo, edgeOne);
graph.AddEdge(vertexOne, vertexTwo, edgeTwo);

Assert.Equal(2, graph.Edges.Count());
Assert.Equal(2, graph.Edges.Intersect(new[] { edgeOne, edgeTwo }).Count());

Assert.Empty(graph.GetEdges(vertexTwo, vertexOne));
Assert.Equal(2, graph.GetEdges(vertexOne, vertexTwo).Count());
Assert.Equal(2, graph.GetEdges(vertexOne, vertexTwo).Intersect(new[] { edgeOne, edgeTwo }).Count());
Expand Down

0 comments on commit c6a9f1c

Please sign in to comment.