Skip to content

Commit

Permalink
Allow UseTransaction to replace existing transaction
Browse files Browse the repository at this point in the history
Fixes #25946
  • Loading branch information
ajcvickers committed Sep 12, 2021
1 parent 8098a16 commit 3d2ebf9
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 13 deletions.
23 changes: 22 additions & 1 deletion src/EFCore.Relational/Storage/RelationalConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,11 @@ private void EnsureNoTransactions()
throw new InvalidOperationException(RelationalStrings.TransactionAlreadyStarted);
}

EnsureNoAmbientOrEnlistedTransactions();
}

private void EnsureNoAmbientOrEnlistedTransactions()
{
if (CurrentAmbientTransaction != null)
{
throw new InvalidOperationException(RelationalStrings.ConflictingAmbientTransaction);
Expand Down Expand Up @@ -467,6 +472,11 @@ private IDbContextTransaction CreateRelationalTransaction(DbTransaction transact
{
if (ShouldUseTransaction(transaction))
{
if (CurrentTransaction != null)
{
CurrentTransaction.Dispose();
}

Open();

transaction = Dependencies.TransactionLogger.TransactionUsed(
Expand Down Expand Up @@ -508,6 +518,11 @@ private IDbContextTransaction CreateRelationalTransaction(DbTransaction transact
{
if (ShouldUseTransaction(transaction))
{
if (CurrentTransaction != null)
{
await CurrentTransaction.DisposeAsync();
}

await OpenAsync(cancellationToken).ConfigureAwait(false);

transaction = await Dependencies.TransactionLogger.TransactionUsedAsync(
Expand Down Expand Up @@ -536,7 +551,13 @@ private bool ShouldUseTransaction([NotNullWhen(true)] DbTransaction? transaction
return false;
}

EnsureNoTransactions();
EnsureNoAmbientOrEnlistedTransactions();

if (CurrentTransaction != null
&& transaction == CurrentTransaction.GetDbTransaction())
{
return false;
}

return true;
}
Expand Down
44 changes: 32 additions & 12 deletions test/EFCore.Relational.Specification.Tests/TransactionTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -935,21 +935,41 @@ public virtual void UseTransaction_throws_if_mismatched_connection()
Assert.Equal(RelationalStrings.TransactionAssociatedWithDifferentConnection, ex.Message);
}

[ConditionalFact]
public virtual void UseTransaction_throws_if_another_transaction_started()
[ConditionalTheory]
[InlineData(true)]
[InlineData(false)]
public virtual async Task UseTransaction_is_no_op_if_same_DbTransaction_is_used(bool async)
{
using var transaction = TestStore.BeginTransaction();
using var context = CreateContextWithConnectionString();
using (context.Database.BeginTransaction(
DirtyReadsOccur
? IsolationLevel.ReadUncommitted
: IsolationLevel.Unspecified))
using (var transaction = TestStore.BeginTransaction())
{
var ex = Assert.Throws<InvalidOperationException>(
() =>
context.Database.UseTransaction(transaction));
Assert.Equal(RelationalStrings.TransactionAlreadyStarted, ex.Message);
using var context = CreateContext();

var currentTransaction = async
? await context.Database.UseTransactionAsync(transaction)
: context.Database.UseTransaction(transaction);

Assert.Same(transaction, currentTransaction!.GetDbTransaction());

var newTransaction = async
? await context.Database.UseTransactionAsync(transaction)
: context.Database.UseTransaction(transaction);

Assert.Same(currentTransaction, newTransaction);
Assert.Same(transaction, newTransaction!.GetDbTransaction());

context.Entry(context.Set<TransactionCustomer>().OrderBy(c => c.Id).First()).State = EntityState.Deleted;

if (async)
{
await context.SaveChangesAsync();
}
else
{
context.SaveChanges();
}
}

AssertStoreInitialState();
}

[ConditionalFact]
Expand Down
186 changes: 186 additions & 0 deletions test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
Expand Down Expand Up @@ -381,6 +382,177 @@ public async Task Retries_SaveChanges_on_execution_failure(
}
}

[ConditionalTheory] // Issue #25946
[MemberData(nameof(DataGenerator.GetBoolCombinations), 3, MemberType = typeof(DataGenerator))]
public async Task Retries_SaveChanges_on_execution_failure_with_two_contexts(
bool realFailure,
bool openConnection,
bool async)
{
CleanContext();
using (var context = CreateContext())
{
using var auditContext = new AuditContext();
var connection = (TestSqlServerConnection)context.GetService<ISqlServerConnection>();
connection.ExecutionFailures.Enqueue(new bool?[] { null, realFailure });
Assert.Equal(ConnectionState.Closed, context.Database.GetDbConnection().State);
if (openConnection)
{
if (async)
{
await context.Database.OpenConnectionAsync();
}
else
{
context.Database.OpenConnection();
}
Assert.Equal(ConnectionState.Open, context.Database.GetDbConnection().State);
}
context.Products.Add(new Product());
context.Products.Add(new Product());
var throwTransientError = true;
if (async)
{
await new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransactionAsync(
(MainContext: context, AuditContext: auditContext),
async (c, ct) =>
{
var result = await c.MainContext.SaveChangesAsync(false, ct);
c.AuditContext.ChangeTracker.Clear();
c.AuditContext.Database.SetDbConnection(c.MainContext.Database.GetDbConnection());
var currentTransaction = c.AuditContext.Database.CurrentTransaction;
if (throwTransientError)
{
Assert.Null(currentTransaction);
}
else
{
Assert.NotNull(currentTransaction);
}
await c.AuditContext.Database.UseTransactionAsync(
c.MainContext.Database.CurrentTransaction!.GetDbTransaction(), ct);
Assert.NotSame(currentTransaction, c.AuditContext.Database.CurrentTransaction);
if (currentTransaction != null)
{
Assert.True(
(bool)typeof(RelationalTransaction).GetRuntimeFields().Single(f => f.Name == "_disposed")
.GetValue(currentTransaction)!);
}
await c.AuditContext.Audits.AddAsync(new Audit(), ct);
await c.AuditContext.SaveChangesAsync(ct);
if (throwTransientError)
{
throwTransientError = false;
throw SqlExceptionFactory.CreateSqlException(10928);
}
return result;
},
(c, _) =>
{
Assert.True(false);
return Task.FromResult(false);
});

context.ChangeTracker.AcceptAllChanges();
}
else
{
new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransaction(
(MainContext: context, AuditContext: auditContext),
c =>
{
var result = c.MainContext.SaveChanges(false);
c.AuditContext.ChangeTracker.Clear();
c.AuditContext.Database.SetDbConnection(c.MainContext.Database.GetDbConnection());
var currentTransaction = c.AuditContext.Database.CurrentTransaction;
if (throwTransientError)
{
Assert.Null(currentTransaction);
}
else
{
Assert.NotNull(currentTransaction);
}
c.AuditContext.Database.UseTransaction(c.MainContext.Database.CurrentTransaction!.GetDbTransaction());
Assert.NotSame(currentTransaction, c.AuditContext.Database.CurrentTransaction);
if (currentTransaction != null)
{
Assert.True(
(bool)typeof(RelationalTransaction).GetRuntimeFields().Single(f => f.Name == "_disposed")
.GetValue(currentTransaction)!);
}
c.AuditContext.Audits.Add(new Audit());
c.AuditContext.SaveChanges();
if (throwTransientError)
{
throwTransientError = false;
throw SqlExceptionFactory.CreateSqlException(10928);
}
return result;
},
c =>
{
Assert.True(false);
return false;
});

context.ChangeTracker.AcceptAllChanges();
}

Assert.Equal(openConnection ? 2 : 3, connection.OpenCount);
Assert.Equal(6, connection.ExecutionCount);

Assert.Equal(
openConnection
? ConnectionState.Open
: ConnectionState.Closed, context.Database.GetDbConnection().State);

if (openConnection)
{
if (async)
{
await context.Database.CloseConnectionAsync();
}
else
{
context.Database.CloseConnection();
}
}

Assert.Equal(ConnectionState.Closed, context.Database.GetDbConnection().State);
}

using (var context = CreateContext())
{
Assert.Equal(2, context.Products.Count());
}
}

[ConditionalTheory]
[MemberData(nameof(DataGenerator.GetBoolCombinations), 2, MemberType = typeof(DataGenerator))]
public async Task Retries_query_on_execution_failure(bool externalStrategy, bool async)
Expand Down Expand Up @@ -629,6 +801,7 @@ public ExecutionStrategyContext(DbContextOptions options)
}

public DbSet<Product> Products { get; set; }
public DbSet<Audit> Audits { get; set; }
}

protected class Product
Expand All @@ -637,6 +810,19 @@ protected class Product
public string Name { get; set; }
}

public class AuditContext : DbContext
{
public DbSet<Audit> Audits { get; set; }

protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
=> optionsBuilder.UseSqlServer();
}

public class Audit
{
public int AuditId { get; set; }
}

protected virtual ExecutionStrategyContext CreateContext()
=> (ExecutionStrategyContext)Fixture.CreateContext();

Expand Down

0 comments on commit 3d2ebf9

Please sign in to comment.