diff --git a/src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs b/src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs
index 741f2ce7f2e..3a421b7357c 100644
--- a/src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs
+++ b/src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs
@@ -14,19 +14,19 @@ public sealed class AutomaticProgressReporter : IDisposable
{
private readonly CancellationToken _cancellationToken;
private readonly CancellationTokenSource _cancellationTokenSource;
+ private readonly IConnection _connection;
private bool _isDisposed;
- private readonly IPlugin _plugin;
private readonly Message _request;
private readonly SemaphoreSlim _semaphore;
private readonly Timer _timer;
private AutomaticProgressReporter(
- IPlugin plugin,
+ IConnection connection,
Message request,
TimeSpan interval,
CancellationToken cancellationToken)
{
- _plugin = plugin;
+ _connection = connection;
_request = request;
_cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_cancellationToken = _cancellationTokenSource.Token;
@@ -72,7 +72,8 @@ public void Dispose()
// does not fire after Dispose(). Otherwise, a progress notification might be sent after a
// response, which would be a fatal plugin protocol error.
_timer.Dispose();
- _plugin.Dispose();
+
+ // Do not dispose of _connection. It is still in use by a plugin.
GC.SuppressFinalize(this);
@@ -94,11 +95,12 @@ public void Dispose()
///
/// Creates a new class.
///
- /// A plugin.
+ /// This class does not take ownership of and dispose of .
+ /// A connection.
/// A request.
/// A progress interval.
/// A cancellation token.
- /// Thrown if
+ /// Thrown if
/// is null.
/// Thrown if
/// is null.
@@ -108,14 +110,14 @@ public void Dispose()
/// Thrown if
/// is cancelled.
public static AutomaticProgressReporter Create(
- IPlugin plugin,
+ IConnection connection,
Message request,
TimeSpan interval,
CancellationToken cancellationToken)
{
- if (plugin == null)
+ if (connection == null)
{
- throw new ArgumentNullException(nameof(plugin));
+ throw new ArgumentNullException(nameof(connection));
}
if (request == null)
@@ -134,7 +136,7 @@ public static AutomaticProgressReporter Create(
cancellationToken.ThrowIfCancellationRequested();
return new AutomaticProgressReporter(
- plugin,
+ connection,
request,
interval,
cancellationToken);
@@ -171,7 +173,7 @@ private void OnTimer(object state)
_request.Method,
new Progress());
- await _plugin.Connection.SendAsync(progress, _cancellationToken);
+ await _connection.SendAsync(progress, _cancellationToken);
}
catch (Exception)
{
diff --git a/src/NuGet.Core/NuGet.Protocol/Plugins/RequestHandlers/GetCredentialsRequestHandler.cs b/src/NuGet.Core/NuGet.Protocol/Plugins/RequestHandlers/GetCredentialsRequestHandler.cs
index 74399072161..be955ec2ca0 100644
--- a/src/NuGet.Core/NuGet.Protocol/Plugins/RequestHandlers/GetCredentialsRequestHandler.cs
+++ b/src/NuGet.Core/NuGet.Protocol/Plugins/RequestHandlers/GetCredentialsRequestHandler.cs
@@ -161,7 +161,7 @@ public async Task HandleResponseAsync(
NetworkCredential credential = null;
using (var progressReporter = AutomaticProgressReporter.Create(
- _plugin,
+ _plugin.Connection,
request,
PluginConstants.ProgressInterval,
cancellationToken))
diff --git a/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/AutomaticProgressReporterTests.cs b/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/AutomaticProgressReporterTests.cs
index 59fd1bc860d..c5052921e4f 100644
--- a/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/AutomaticProgressReporterTests.cs
+++ b/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/AutomaticProgressReporterTests.cs
@@ -2,8 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
-using System.Collections.Generic;
-using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Moq;
@@ -14,18 +12,18 @@ namespace NuGet.Protocol.Plugins.Tests
public class AutomaticProgressReporterTests
{
[Fact]
- public void Create_ThrowsForNullPlugin()
+ public void Create_ThrowsForNullConnection()
{
using (var test = new AutomaticProgressReporterTest())
{
var exception = Assert.Throws(
() => AutomaticProgressReporter.Create(
- plugin: null,
+ connection: null,
request: test.Request,
interval: test.Interval,
cancellationToken: CancellationToken.None));
- Assert.Equal("plugin", exception.ParamName);
+ Assert.Equal("connection", exception.ParamName);
}
}
@@ -36,7 +34,7 @@ public void Create_ThrowsForNullRequest()
{
var exception = Assert.Throws(
() => AutomaticProgressReporter.Create(
- test.Plugin.Object,
+ test.Connection.Object,
request: null,
interval: test.Interval,
cancellationToken: CancellationToken.None));
@@ -72,7 +70,7 @@ public void Create_ThrowsIfCancelled()
{
Assert.Throws(
() => AutomaticProgressReporter.Create(
- test.Plugin.Object,
+ test.Connection.Object,
test.Request,
test.Interval,
new CancellationToken(canceled: true)));
@@ -87,6 +85,27 @@ public void Dispose_DisposesDisposables()
}
}
+ [Fact]
+ public void Dispose_DoesNotDisposeConnection()
+ {
+ var connection = new Mock(MockBehavior.Strict);
+ var request = MessageUtilities.Create(
+ requestId: "a",
+ type: MessageType.Request,
+ method: MessageMethod.GetServiceIndex,
+ payload: new GetServiceIndexRequest(packageSourceRepository: "https://unit.test"));
+
+ using (var reporter = AutomaticProgressReporter.Create(
+ connection.Object,
+ request,
+ TimeSpan.FromHours(1),
+ CancellationToken.None))
+ {
+ }
+
+ connection.Verify();
+ }
+
[Fact]
public void Dispose_IsIdempotent()
{
@@ -114,7 +133,7 @@ private static void VerifyInvalidInterval(TimeSpan interval)
{
var exception = Assert.Throws(
() => AutomaticProgressReporter.Create(
- test.Plugin.Object,
+ test.Connection.Object,
test.Request,
interval,
CancellationToken.None));
@@ -126,13 +145,12 @@ private static void VerifyInvalidInterval(TimeSpan interval)
private sealed class AutomaticProgressReporterTest : IDisposable
{
private int _actualSentCount;
- private readonly Mock _connection;
private readonly int _expectedSentCount;
private bool _isDisposed;
internal CancellationTokenSource CancellationTokenSource { get; }
+ internal Mock Connection { get; }
internal TimeSpan Interval { get; }
- internal Mock Plugin { get; }
internal AutomaticProgressReporter Reporter { get; }
internal Message Request { get; }
internal ManualResetEventSlim SentEvent { get; }
@@ -148,13 +166,10 @@ internal AutomaticProgressReporterTest(TimeSpan? interval = null, int expectedSe
CancellationTokenSource = new CancellationTokenSource();
Interval = interval.HasValue ? interval.Value : ProtocolConstants.MaxTimeout;
SentEvent = new ManualResetEventSlim(initialState: false);
- Plugin = new Mock(MockBehavior.Strict);
- Plugin.Setup(x => x.Dispose());
+ Connection = new Mock(MockBehavior.Strict);
- _connection = new Mock(MockBehavior.Strict);
-
- _connection.Setup(x => x.SendAsync(
+ Connection.Setup(x => x.SendAsync(
It.IsNotNull(),
It.IsAny()))
.Callback(
@@ -169,16 +184,13 @@ internal AutomaticProgressReporterTest(TimeSpan? interval = null, int expectedSe
})
.Returns(Task.FromResult(0));
- Plugin.SetupGet(x => x.Connection)
- .Returns(_connection.Object);
-
Request = MessageUtilities.Create(
requestId: "a",
type: MessageType.Request,
method: MessageMethod.Handshake,
payload: payload);
Reporter = AutomaticProgressReporter.Create(
- Plugin.Object,
+ Connection.Object,
Request,
Interval,
CancellationTokenSource.Token);
@@ -207,10 +219,9 @@ public void Dispose()
var connectionTimes = _expectedSentCount == 0 ? Times.Never() : Times.AtLeast(_expectedSentCount);
- _connection.Verify(x => x.SendAsync(
+ Connection.Verify(x => x.SendAsync(
It.IsNotNull(),
It.IsAny()), connectionTimes);
- Plugin.Verify(x => x.Dispose(), Times.Once);
}
}
}
diff --git a/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/RequestHandlers/GetCredentialsRequestHandlerTests.cs b/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/RequestHandlers/GetCredentialsRequestHandlerTests.cs
index e09dba5a8f5..8fe4d7e095d 100644
--- a/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/RequestHandlers/GetCredentialsRequestHandlerTests.cs
+++ b/test/NuGet.Core.Tests/NuGet.Protocol.Tests/Plugins/RequestHandlers/GetCredentialsRequestHandlerTests.cs
@@ -209,8 +209,10 @@ public async Task HandleResponseAsync_ReturnsNotFoundIfPackageSourceNotFound()
It.IsAny()))
.ReturnsAsync((ICredentials)null);
+ var plugin = CreateMockPlugin();
+
using (var provider = new GetCredentialsRequestHandler(
- Mock.Of(),
+ plugin,
Mock.Of(),
credentialService.Object))
{
@@ -294,10 +296,9 @@ public async Task HandleResponseAsync_ReturnsPackageSourceCredentialsFromCredent
It.IsAny()))
.Returns(Task.FromResult(credentials.Object));
- using (var provider = new GetCredentialsRequestHandler(
- Mock.Of(),
- proxy,
- credentialService.Object))
+ var plugin = CreateMockPlugin();
+
+ using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object))
{
var request = CreateRequest(
MessageType.Request,
@@ -324,8 +325,10 @@ await provider.HandleResponseAsync(
[Fact]
public async Task HandleResponseAsync_ReturnsNullPackageSourceCredentialsIfPackageSourceCredentialsAreInvalidAndCredentialServiceIsNull()
{
+ var plugin = CreateMockPlugin();
+
using (var provider = new GetCredentialsRequestHandler(
- Mock.Of(),
+ plugin,
Mock.Of(),
credentialService: null))
{
@@ -371,10 +374,9 @@ public async Task HandleResponseAsync_ReturnsNullPackageSourceCredentialsIfNoCre
It.IsAny()))
.Returns(Task.FromResult(credentials.Object));
- using (var provider = new GetCredentialsRequestHandler(
- Mock.Of(),
- proxy,
- credentialService.Object))
+ var plugin = CreateMockPlugin();
+
+ using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object))
{
var request = CreateRequest(
MessageType.Request,
@@ -423,10 +425,9 @@ public async Task HandleResponseAsync_ReturnsProxyCredentialsFromCredentialServi
It.IsAny()))
.Returns(Task.FromResult(credentials.Object));
- using (var provider = new GetCredentialsRequestHandler(
- Mock.Of(),
- proxy.Object,
- credentialService.Object))
+ var plugin = CreateMockPlugin();
+
+ using (var provider = new GetCredentialsRequestHandler(plugin, proxy.Object, credentialService.Object))
{
var request = CreateRequest(
MessageType.Request,
@@ -453,8 +454,10 @@ await provider.HandleResponseAsync(
[Fact]
public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfCredentialServiceIsNull()
{
+ var plugin = CreateMockPlugin();
+
using (var provider = new GetCredentialsRequestHandler(
- Mock.Of(),
+ plugin,
Mock.Of(),
credentialService: null))
{
@@ -500,10 +503,9 @@ public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfNoCredentials
It.IsAny()))
.Returns(Task.FromResult(credentials.Object));
- using (var provider = new GetCredentialsRequestHandler(
- Mock.Of(),
- proxy,
- credentialService.Object))
+ var plugin = CreateMockPlugin();
+
+ using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object))
{
var request = CreateRequest(
MessageType.Request,
@@ -530,8 +532,10 @@ await provider.HandleResponseAsync(
[Fact]
public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfNoProxy()
{
+ var plugin = CreateMockPlugin();
+
using (var provider = new GetCredentialsRequestHandler(
- Mock.Of(),
+ plugin,
proxy: null,
credentialService: Mock.Of()))
{
@@ -559,10 +563,19 @@ await provider.HandleResponseAsync(
private GetCredentialsRequestHandler CreateDefaultRequestHandler()
{
- return new GetCredentialsRequestHandler(
- Mock.Of(),
- Mock.Of(),
- Mock.Of());
+ var plugin = CreateMockPlugin();
+
+ return new GetCredentialsRequestHandler(plugin, Mock.Of(), Mock.Of());
+ }
+
+ private static IPlugin CreateMockPlugin()
+ {
+ var plugin = new Mock();
+
+ plugin.SetupGet(x => x.Connection)
+ .Returns(Mock.Of());
+
+ return plugin.Object;
}
private static Message CreateRequest(MessageType type, GetCredentialsRequest payload = null)