diff --git a/src/EFCore.Relational/Storage/RelationalConnection.cs b/src/EFCore.Relational/Storage/RelationalConnection.cs index e3f742816a1..4a799f84d5d 100644 --- a/src/EFCore.Relational/Storage/RelationalConnection.cs +++ b/src/EFCore.Relational/Storage/RelationalConnection.cs @@ -425,6 +425,11 @@ private void EnsureNoTransactions() throw new InvalidOperationException(RelationalStrings.TransactionAlreadyStarted); } + EnsureNoAmbientOrEnlistedTransactions(); + } + + private void EnsureNoAmbientOrEnlistedTransactions() + { if (CurrentAmbientTransaction != null) { throw new InvalidOperationException(RelationalStrings.ConflictingAmbientTransaction); @@ -467,6 +472,11 @@ private IDbContextTransaction CreateRelationalTransaction(DbTransaction transact { if (ShouldUseTransaction(transaction)) { + if (CurrentTransaction != null) + { + CurrentTransaction.Dispose(); + } + Open(); transaction = Dependencies.TransactionLogger.TransactionUsed( @@ -508,6 +518,11 @@ private IDbContextTransaction CreateRelationalTransaction(DbTransaction transact { if (ShouldUseTransaction(transaction)) { + if (CurrentTransaction != null) + { + await CurrentTransaction.DisposeAsync(); + } + await OpenAsync(cancellationToken).ConfigureAwait(false); transaction = await Dependencies.TransactionLogger.TransactionUsedAsync( @@ -536,7 +551,13 @@ private bool ShouldUseTransaction([NotNullWhen(true)] DbTransaction? transaction return false; } - EnsureNoTransactions(); + EnsureNoAmbientOrEnlistedTransactions(); + + if (CurrentTransaction != null + && transaction == CurrentTransaction.GetDbTransaction()) + { + return false; + } return true; } diff --git a/test/EFCore.Relational.Specification.Tests/TransactionTestBase.cs b/test/EFCore.Relational.Specification.Tests/TransactionTestBase.cs index e1c43e6f11e..c66ade4fdfc 100644 --- a/test/EFCore.Relational.Specification.Tests/TransactionTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/TransactionTestBase.cs @@ -935,21 +935,41 @@ public virtual void UseTransaction_throws_if_mismatched_connection() Assert.Equal(RelationalStrings.TransactionAssociatedWithDifferentConnection, ex.Message); } - [ConditionalFact] - public virtual void UseTransaction_throws_if_another_transaction_started() + [ConditionalTheory] + [InlineData(true)] + [InlineData(false)] + public virtual async Task UseTransaction_is_no_op_if_same_DbTransaction_is_used(bool async) { - using var transaction = TestStore.BeginTransaction(); - using var context = CreateContextWithConnectionString(); - using (context.Database.BeginTransaction( - DirtyReadsOccur - ? IsolationLevel.ReadUncommitted - : IsolationLevel.Unspecified)) + using (var transaction = TestStore.BeginTransaction()) { - var ex = Assert.Throws( - () => - context.Database.UseTransaction(transaction)); - Assert.Equal(RelationalStrings.TransactionAlreadyStarted, ex.Message); + using var context = CreateContext(); + + var currentTransaction = async + ? await context.Database.UseTransactionAsync(transaction) + : context.Database.UseTransaction(transaction); + + Assert.Same(transaction, currentTransaction!.GetDbTransaction()); + + var newTransaction = async + ? await context.Database.UseTransactionAsync(transaction) + : context.Database.UseTransaction(transaction); + + Assert.Same(currentTransaction, newTransaction); + Assert.Same(transaction, newTransaction!.GetDbTransaction()); + + context.Entry(context.Set().OrderBy(c => c.Id).First()).State = EntityState.Deleted; + + if (async) + { + await context.SaveChangesAsync(); + } + else + { + context.SaveChanges(); + } } + + AssertStoreInitialState(); } [ConditionalFact] diff --git a/test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs b/test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs index 72f3ec9c8cf..d823184e8d6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Data; using System.Linq; +using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.Diagnostics; @@ -381,6 +382,177 @@ public async Task Retries_SaveChanges_on_execution_failure( } } + [ConditionalTheory] // Issue #25946 + [MemberData(nameof(DataGenerator.GetBoolCombinations), 3, MemberType = typeof(DataGenerator))] + public async Task Retries_SaveChanges_on_execution_failure_with_two_contexts( + bool realFailure, + bool openConnection, + bool async) + { + CleanContext(); + + using (var context = CreateContext()) + { + using var auditContext = new AuditContext(); + + var connection = (TestSqlServerConnection)context.GetService(); + + connection.ExecutionFailures.Enqueue(new bool?[] { null, realFailure }); + + Assert.Equal(ConnectionState.Closed, context.Database.GetDbConnection().State); + + if (openConnection) + { + if (async) + { + await context.Database.OpenConnectionAsync(); + } + else + { + context.Database.OpenConnection(); + } + + Assert.Equal(ConnectionState.Open, context.Database.GetDbConnection().State); + } + + context.Products.Add(new Product()); + context.Products.Add(new Product()); + + var throwTransientError = true; + + if (async) + { + await new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransactionAsync( + (MainContext: context, AuditContext: auditContext), + async (c, ct) => + { + var result = await c.MainContext.SaveChangesAsync(false, ct); + + c.AuditContext.ChangeTracker.Clear(); + c.AuditContext.Database.SetDbConnection(c.MainContext.Database.GetDbConnection()); + + var currentTransaction = c.AuditContext.Database.CurrentTransaction; + if (throwTransientError) + { + Assert.Null(currentTransaction); + } + else + { + Assert.NotNull(currentTransaction); + } + + await c.AuditContext.Database.UseTransactionAsync( + c.MainContext.Database.CurrentTransaction!.GetDbTransaction(), ct); + + Assert.NotSame(currentTransaction, c.AuditContext.Database.CurrentTransaction); + + if (currentTransaction != null) + { + Assert.True( + (bool)typeof(RelationalTransaction).GetRuntimeFields().Single(f => f.Name == "_disposed") + .GetValue(currentTransaction)!); + } + + await c.AuditContext.Audits.AddAsync(new Audit(), ct); + await c.AuditContext.SaveChangesAsync(ct); + + if (throwTransientError) + { + throwTransientError = false; + throw SqlExceptionFactory.CreateSqlException(10928); + } + + return result; + }, + (c, _) => + { + Assert.True(false); + return Task.FromResult(false); + }); + + context.ChangeTracker.AcceptAllChanges(); + } + else + { + new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransaction( + (MainContext: context, AuditContext: auditContext), + c => + { + var result = c.MainContext.SaveChanges(false); + + c.AuditContext.ChangeTracker.Clear(); + c.AuditContext.Database.SetDbConnection(c.MainContext.Database.GetDbConnection()); + + var currentTransaction = c.AuditContext.Database.CurrentTransaction; + if (throwTransientError) + { + Assert.Null(currentTransaction); + } + else + { + Assert.NotNull(currentTransaction); + } + + c.AuditContext.Database.UseTransaction(c.MainContext.Database.CurrentTransaction!.GetDbTransaction()); + + Assert.NotSame(currentTransaction, c.AuditContext.Database.CurrentTransaction); + + if (currentTransaction != null) + { + Assert.True( + (bool)typeof(RelationalTransaction).GetRuntimeFields().Single(f => f.Name == "_disposed") + .GetValue(currentTransaction)!); + } + + c.AuditContext.Audits.Add(new Audit()); + c.AuditContext.SaveChanges(); + + if (throwTransientError) + { + throwTransientError = false; + throw SqlExceptionFactory.CreateSqlException(10928); + } + + return result; + }, + c => + { + Assert.True(false); + return false; + }); + + context.ChangeTracker.AcceptAllChanges(); + } + + Assert.Equal(openConnection ? 2 : 3, connection.OpenCount); + Assert.Equal(6, connection.ExecutionCount); + + Assert.Equal( + openConnection + ? ConnectionState.Open + : ConnectionState.Closed, context.Database.GetDbConnection().State); + + if (openConnection) + { + if (async) + { + await context.Database.CloseConnectionAsync(); + } + else + { + context.Database.CloseConnection(); + } + } + + Assert.Equal(ConnectionState.Closed, context.Database.GetDbConnection().State); + } + + using (var context = CreateContext()) + { + Assert.Equal(2, context.Products.Count()); + } + } + [ConditionalTheory] [MemberData(nameof(DataGenerator.GetBoolCombinations), 2, MemberType = typeof(DataGenerator))] public async Task Retries_query_on_execution_failure(bool externalStrategy, bool async) @@ -629,6 +801,7 @@ public ExecutionStrategyContext(DbContextOptions options) } public DbSet Products { get; set; } + public DbSet Audits { get; set; } } protected class Product @@ -637,6 +810,19 @@ protected class Product public string Name { get; set; } } + public class AuditContext : DbContext + { + public DbSet Audits { get; set; } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + => optionsBuilder.UseSqlServer(); + } + + public class Audit + { + public int AuditId { get; set; } + } + protected virtual ExecutionStrategyContext CreateContext() => (ExecutionStrategyContext)Fixture.CreateContext();