diff --git a/src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs b/src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs index 741f2ce7f2e..3a421b7357c 100644 --- a/src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs +++ b/src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs @@ -14,19 +14,19 @@ public sealed class AutomaticProgressReporter : IDisposable { private readonly CancellationToken _cancellationToken; private readonly CancellationTokenSource _cancellationTokenSource; + private readonly IConnection _connection; private bool _isDisposed; - private readonly IPlugin _plugin; private readonly Message _request; private readonly SemaphoreSlim _semaphore; private readonly Timer _timer; private AutomaticProgressReporter( - IPlugin plugin, + IConnection connection, Message request, TimeSpan interval, CancellationToken cancellationToken) { - _plugin = plugin; + _connection = connection; _request = request; _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); _cancellationToken = _cancellationTokenSource.Token; @@ -72,7 +72,8 @@ public void Dispose() // does not fire after Dispose(). Otherwise, a progress notification might be sent after a // response, which would be a fatal plugin protocol error. _timer.Dispose(); - _plugin.Dispose(); + + // Do not dispose of _connection. It is still in use by a plugin. GC.SuppressFinalize(this); @@ -94,11 +95,12 @@ public void Dispose() /// /// Creates a new class. /// - /// A plugin. + /// This class does not take ownership of and dispose of . + /// A connection. /// A request. /// A progress interval. /// A cancellation token. - /// Thrown if + /// Thrown if /// is null. /// Thrown if /// is null. @@ -108,14 +110,14 @@ public void Dispose() /// Thrown if /// is cancelled. public static AutomaticProgressReporter Create( - IPlugin plugin, + IConnection connection, Message request, TimeSpan interval, CancellationToken cancellationToken) { - if (plugin == null) + if (connection == null) { - throw new ArgumentNullException(nameof(plugin)); + throw new ArgumentNullException(nameof(connection)); } if (request == null) @@ -134,7 +136,7 @@ public static AutomaticProgressReporter Create( cancellationToken.ThrowIfCancellationRequested(); return new AutomaticProgressReporter( - plugin, + connection, request, interval, cancellationToken); @@ -171,7 +173,7 @@ private void OnTimer(object state) _request.Method, new Progress()); - await _plugin.Connection.SendAsync(progress, _cancellationToken); + await _connection.SendAsync(progress, _cancellationToken); } catch (Exception) { diff --git a/src/NuGet.Core/NuGet.Protocol/Plugins/RequestHandlers/GetCredentialsRequestHandler.cs b/src/NuGet.Core/NuGet.Protocol/Plugins/RequestHandlers/GetCredentialsRequestHandler.cs index 74399072161..be955ec2ca0 100644 --- a/src/NuGet.Core/NuGet.Protocol/Plugins/RequestHandlers/GetCredentialsRequestHandler.cs +++ b/src/NuGet.Core/NuGet.Protocol/Plugins/RequestHandlers/GetCredentialsRequestHandler.cs @@ -161,7 +161,7 @@ public async Task HandleResponseAsync( NetworkCredential credential = null; using (var progressReporter = AutomaticProgressReporter.Create( - _plugin, + _plugin.Connection, request, PluginConstants.ProgressInterval, cancellationToken)) diff --git a/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/AutomaticProgressReporterTests.cs b/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/AutomaticProgressReporterTests.cs index 59fd1bc860d..c5052921e4f 100644 --- a/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/AutomaticProgressReporterTests.cs +++ b/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/AutomaticProgressReporterTests.cs @@ -2,8 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.Text; using System.Threading; using System.Threading.Tasks; using Moq; @@ -14,18 +12,18 @@ namespace NuGet.Protocol.Plugins.Tests public class AutomaticProgressReporterTests { [Fact] - public void Create_ThrowsForNullPlugin() + public void Create_ThrowsForNullConnection() { using (var test = new AutomaticProgressReporterTest()) { var exception = Assert.Throws( () => AutomaticProgressReporter.Create( - plugin: null, + connection: null, request: test.Request, interval: test.Interval, cancellationToken: CancellationToken.None)); - Assert.Equal("plugin", exception.ParamName); + Assert.Equal("connection", exception.ParamName); } } @@ -36,7 +34,7 @@ public void Create_ThrowsForNullRequest() { var exception = Assert.Throws( () => AutomaticProgressReporter.Create( - test.Plugin.Object, + test.Connection.Object, request: null, interval: test.Interval, cancellationToken: CancellationToken.None)); @@ -72,7 +70,7 @@ public void Create_ThrowsIfCancelled() { Assert.Throws( () => AutomaticProgressReporter.Create( - test.Plugin.Object, + test.Connection.Object, test.Request, test.Interval, new CancellationToken(canceled: true))); @@ -87,6 +85,27 @@ public void Dispose_DisposesDisposables() } } + [Fact] + public void Dispose_DoesNotDisposeConnection() + { + var connection = new Mock(MockBehavior.Strict); + var request = MessageUtilities.Create( + requestId: "a", + type: MessageType.Request, + method: MessageMethod.GetServiceIndex, + payload: new GetServiceIndexRequest(packageSourceRepository: "https://unit.test")); + + using (var reporter = AutomaticProgressReporter.Create( + connection.Object, + request, + TimeSpan.FromHours(1), + CancellationToken.None)) + { + } + + connection.Verify(); + } + [Fact] public void Dispose_IsIdempotent() { @@ -114,7 +133,7 @@ private static void VerifyInvalidInterval(TimeSpan interval) { var exception = Assert.Throws( () => AutomaticProgressReporter.Create( - test.Plugin.Object, + test.Connection.Object, test.Request, interval, CancellationToken.None)); @@ -126,13 +145,12 @@ private static void VerifyInvalidInterval(TimeSpan interval) private sealed class AutomaticProgressReporterTest : IDisposable { private int _actualSentCount; - private readonly Mock _connection; private readonly int _expectedSentCount; private bool _isDisposed; internal CancellationTokenSource CancellationTokenSource { get; } + internal Mock Connection { get; } internal TimeSpan Interval { get; } - internal Mock Plugin { get; } internal AutomaticProgressReporter Reporter { get; } internal Message Request { get; } internal ManualResetEventSlim SentEvent { get; } @@ -148,13 +166,10 @@ internal AutomaticProgressReporterTest(TimeSpan? interval = null, int expectedSe CancellationTokenSource = new CancellationTokenSource(); Interval = interval.HasValue ? interval.Value : ProtocolConstants.MaxTimeout; SentEvent = new ManualResetEventSlim(initialState: false); - Plugin = new Mock(MockBehavior.Strict); - Plugin.Setup(x => x.Dispose()); + Connection = new Mock(MockBehavior.Strict); - _connection = new Mock(MockBehavior.Strict); - - _connection.Setup(x => x.SendAsync( + Connection.Setup(x => x.SendAsync( It.IsNotNull(), It.IsAny())) .Callback( @@ -169,16 +184,13 @@ internal AutomaticProgressReporterTest(TimeSpan? interval = null, int expectedSe }) .Returns(Task.FromResult(0)); - Plugin.SetupGet(x => x.Connection) - .Returns(_connection.Object); - Request = MessageUtilities.Create( requestId: "a", type: MessageType.Request, method: MessageMethod.Handshake, payload: payload); Reporter = AutomaticProgressReporter.Create( - Plugin.Object, + Connection.Object, Request, Interval, CancellationTokenSource.Token); @@ -207,10 +219,9 @@ public void Dispose() var connectionTimes = _expectedSentCount == 0 ? Times.Never() : Times.AtLeast(_expectedSentCount); - _connection.Verify(x => x.SendAsync( + Connection.Verify(x => x.SendAsync( It.IsNotNull(), It.IsAny()), connectionTimes); - Plugin.Verify(x => x.Dispose(), Times.Once); } } } diff --git a/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/RequestHandlers/GetCredentialsRequestHandlerTests.cs b/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/RequestHandlers/GetCredentialsRequestHandlerTests.cs index e09dba5a8f5..8fe4d7e095d 100644 --- a/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/RequestHandlers/GetCredentialsRequestHandlerTests.cs +++ b/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/RequestHandlers/GetCredentialsRequestHandlerTests.cs @@ -209,8 +209,10 @@ public async Task HandleResponseAsync_ReturnsNotFoundIfPackageSourceNotFound() It.IsAny())) .ReturnsAsync((ICredentials)null); + var plugin = CreateMockPlugin(); + using (var provider = new GetCredentialsRequestHandler( - Mock.Of(), + plugin, Mock.Of(), credentialService.Object)) { @@ -294,10 +296,9 @@ public async Task HandleResponseAsync_ReturnsPackageSourceCredentialsFromCredent It.IsAny())) .Returns(Task.FromResult(credentials.Object)); - using (var provider = new GetCredentialsRequestHandler( - Mock.Of(), - proxy, - credentialService.Object)) + var plugin = CreateMockPlugin(); + + using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object)) { var request = CreateRequest( MessageType.Request, @@ -324,8 +325,10 @@ await provider.HandleResponseAsync( [Fact] public async Task HandleResponseAsync_ReturnsNullPackageSourceCredentialsIfPackageSourceCredentialsAreInvalidAndCredentialServiceIsNull() { + var plugin = CreateMockPlugin(); + using (var provider = new GetCredentialsRequestHandler( - Mock.Of(), + plugin, Mock.Of(), credentialService: null)) { @@ -371,10 +374,9 @@ public async Task HandleResponseAsync_ReturnsNullPackageSourceCredentialsIfNoCre It.IsAny())) .Returns(Task.FromResult(credentials.Object)); - using (var provider = new GetCredentialsRequestHandler( - Mock.Of(), - proxy, - credentialService.Object)) + var plugin = CreateMockPlugin(); + + using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object)) { var request = CreateRequest( MessageType.Request, @@ -423,10 +425,9 @@ public async Task HandleResponseAsync_ReturnsProxyCredentialsFromCredentialServi It.IsAny())) .Returns(Task.FromResult(credentials.Object)); - using (var provider = new GetCredentialsRequestHandler( - Mock.Of(), - proxy.Object, - credentialService.Object)) + var plugin = CreateMockPlugin(); + + using (var provider = new GetCredentialsRequestHandler(plugin, proxy.Object, credentialService.Object)) { var request = CreateRequest( MessageType.Request, @@ -453,8 +454,10 @@ await provider.HandleResponseAsync( [Fact] public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfCredentialServiceIsNull() { + var plugin = CreateMockPlugin(); + using (var provider = new GetCredentialsRequestHandler( - Mock.Of(), + plugin, Mock.Of(), credentialService: null)) { @@ -500,10 +503,9 @@ public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfNoCredentials It.IsAny())) .Returns(Task.FromResult(credentials.Object)); - using (var provider = new GetCredentialsRequestHandler( - Mock.Of(), - proxy, - credentialService.Object)) + var plugin = CreateMockPlugin(); + + using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object)) { var request = CreateRequest( MessageType.Request, @@ -530,8 +532,10 @@ await provider.HandleResponseAsync( [Fact] public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfNoProxy() { + var plugin = CreateMockPlugin(); + using (var provider = new GetCredentialsRequestHandler( - Mock.Of(), + plugin, proxy: null, credentialService: Mock.Of())) { @@ -559,10 +563,19 @@ await provider.HandleResponseAsync( private GetCredentialsRequestHandler CreateDefaultRequestHandler() { - return new GetCredentialsRequestHandler( - Mock.Of(), - Mock.Of(), - Mock.Of()); + var plugin = CreateMockPlugin(); + + return new GetCredentialsRequestHandler(plugin, Mock.Of(), Mock.Of()); + } + + private static IPlugin CreateMockPlugin() + { + var plugin = new Mock(); + + plugin.SetupGet(x => x.Connection) + .Returns(Mock.Of()); + + return plugin.Object; } private static Message CreateRequest(MessageType type, GetCredentialsRequest payload = null)