Skip to content

Commit

Permalink
Change AutomaticProgressReporter to accept a connection
Browse files Browse the repository at this point in the history
  • Loading branch information
dtivel committed May 24, 2017
1 parent a767c82 commit 8c10f33
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 57 deletions.
24 changes: 13 additions & 11 deletions src/NuGet.Core/NuGet.Protocol/Plugins/AutomaticProgressReporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ public sealed class AutomaticProgressReporter : IDisposable
{
private readonly CancellationToken _cancellationToken;
private readonly CancellationTokenSource _cancellationTokenSource;
private readonly IConnection _connection;
private bool _isDisposed;
private readonly IPlugin _plugin;
private readonly Message _request;
private readonly SemaphoreSlim _semaphore;
private readonly Timer _timer;

private AutomaticProgressReporter(
IPlugin plugin,
IConnection connection,
Message request,
TimeSpan interval,
CancellationToken cancellationToken)
{
_plugin = plugin;
_connection = connection;
_request = request;
_cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_cancellationToken = _cancellationTokenSource.Token;
Expand Down Expand Up @@ -72,7 +72,8 @@ public void Dispose()
// does not fire after Dispose(). Otherwise, a progress notification might be sent after a
// response, which would be a fatal plugin protocol error.
_timer.Dispose();
_plugin.Dispose();

// Do not dispose of _connection. It is still in use by a plugin.

GC.SuppressFinalize(this);

Expand All @@ -94,11 +95,12 @@ public void Dispose()
/// <summary>
/// Creates a new <see cref="AutomaticProgressReporter" /> class.
/// </summary>
/// <param name="plugin">A plugin.</param>
/// <remarks>This class does not take ownership of and dispose of <paramref name="connection" />.</remarks>
/// <param name="connection">A connection.</param>
/// <param name="request">A request.</param>
/// <param name="interval">A progress interval.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="plugin" />
/// <exception cref="ArgumentNullException">Thrown if <paramref name="connection" />
/// is <c>null</c>.</exception>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="request" />
/// is <c>null</c>.</exception>
Expand All @@ -108,14 +110,14 @@ public void Dispose()
/// <exception cref="OperationCanceledException">Thrown if <paramref name="cancellationToken" />
/// is cancelled.</exception>
public static AutomaticProgressReporter Create(
IPlugin plugin,
IConnection connection,
Message request,
TimeSpan interval,
CancellationToken cancellationToken)
{
if (plugin == null)
if (connection == null)
{
throw new ArgumentNullException(nameof(plugin));
throw new ArgumentNullException(nameof(connection));
}

if (request == null)
Expand All @@ -134,7 +136,7 @@ public static AutomaticProgressReporter Create(
cancellationToken.ThrowIfCancellationRequested();

return new AutomaticProgressReporter(
plugin,
connection,
request,
interval,
cancellationToken);
Expand Down Expand Up @@ -171,7 +173,7 @@ private void OnTimer(object state)
_request.Method,
new Progress());
await _plugin.Connection.SendAsync(progress, _cancellationToken);
await _connection.SendAsync(progress, _cancellationToken);
}
catch (Exception)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public async Task HandleResponseAsync(
NetworkCredential credential = null;

using (var progressReporter = AutomaticProgressReporter.Create(
_plugin,
_plugin.Connection,
request,
PluginConstants.ProgressInterval,
cancellationToken))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Moq;
Expand All @@ -14,18 +12,18 @@ namespace NuGet.Protocol.Plugins.Tests
public class AutomaticProgressReporterTests
{
[Fact]
public void Create_ThrowsForNullPlugin()
public void Create_ThrowsForNullConnection()
{
using (var test = new AutomaticProgressReporterTest())
{
var exception = Assert.Throws<ArgumentNullException>(
() => AutomaticProgressReporter.Create(
plugin: null,
connection: null,
request: test.Request,
interval: test.Interval,
cancellationToken: CancellationToken.None));

Assert.Equal("plugin", exception.ParamName);
Assert.Equal("connection", exception.ParamName);
}
}

Expand All @@ -36,7 +34,7 @@ public void Create_ThrowsForNullRequest()
{
var exception = Assert.Throws<ArgumentNullException>(
() => AutomaticProgressReporter.Create(
test.Plugin.Object,
test.Connection.Object,
request: null,
interval: test.Interval,
cancellationToken: CancellationToken.None));
Expand Down Expand Up @@ -72,7 +70,7 @@ public void Create_ThrowsIfCancelled()
{
Assert.Throws<OperationCanceledException>(
() => AutomaticProgressReporter.Create(
test.Plugin.Object,
test.Connection.Object,
test.Request,
test.Interval,
new CancellationToken(canceled: true)));
Expand All @@ -87,6 +85,27 @@ public void Dispose_DisposesDisposables()
}
}

[Fact]
public void Dispose_DoesNotDisposeConnection()
{
var connection = new Mock<IConnection>(MockBehavior.Strict);
var request = MessageUtilities.Create(
requestId: "a",
type: MessageType.Request,
method: MessageMethod.GetServiceIndex,
payload: new GetServiceIndexRequest(packageSourceRepository: "https://unit.test"));

using (var reporter = AutomaticProgressReporter.Create(
connection.Object,
request,
TimeSpan.FromHours(1),
CancellationToken.None))
{
}

connection.Verify();
}

[Fact]
public void Dispose_IsIdempotent()
{
Expand Down Expand Up @@ -114,7 +133,7 @@ private static void VerifyInvalidInterval(TimeSpan interval)
{
var exception = Assert.Throws<ArgumentOutOfRangeException>(
() => AutomaticProgressReporter.Create(
test.Plugin.Object,
test.Connection.Object,
test.Request,
interval,
CancellationToken.None));
Expand All @@ -126,13 +145,12 @@ private static void VerifyInvalidInterval(TimeSpan interval)
private sealed class AutomaticProgressReporterTest : IDisposable
{
private int _actualSentCount;
private readonly Mock<IConnection> _connection;
private readonly int _expectedSentCount;
private bool _isDisposed;

internal CancellationTokenSource CancellationTokenSource { get; }
internal Mock<IConnection> Connection { get; }
internal TimeSpan Interval { get; }
internal Mock<IPlugin> Plugin { get; }
internal AutomaticProgressReporter Reporter { get; }
internal Message Request { get; }
internal ManualResetEventSlim SentEvent { get; }
Expand All @@ -148,13 +166,10 @@ internal AutomaticProgressReporterTest(TimeSpan? interval = null, int expectedSe
CancellationTokenSource = new CancellationTokenSource();
Interval = interval.HasValue ? interval.Value : ProtocolConstants.MaxTimeout;
SentEvent = new ManualResetEventSlim(initialState: false);
Plugin = new Mock<IPlugin>(MockBehavior.Strict);

Plugin.Setup(x => x.Dispose());
Connection = new Mock<IConnection>(MockBehavior.Strict);

_connection = new Mock<IConnection>(MockBehavior.Strict);

_connection.Setup(x => x.SendAsync(
Connection.Setup(x => x.SendAsync(
It.IsNotNull<Message>(),
It.IsAny<CancellationToken>()))
.Callback<Message, CancellationToken>(
Expand All @@ -169,16 +184,13 @@ internal AutomaticProgressReporterTest(TimeSpan? interval = null, int expectedSe
})
.Returns(Task.FromResult(0));

Plugin.SetupGet(x => x.Connection)
.Returns(_connection.Object);

Request = MessageUtilities.Create(
requestId: "a",
type: MessageType.Request,
method: MessageMethod.Handshake,
payload: payload);
Reporter = AutomaticProgressReporter.Create(
Plugin.Object,
Connection.Object,
Request,
Interval,
CancellationTokenSource.Token);
Expand Down Expand Up @@ -207,10 +219,9 @@ public void Dispose()

var connectionTimes = _expectedSentCount == 0 ? Times.Never() : Times.AtLeast(_expectedSentCount);

_connection.Verify(x => x.SendAsync(
Connection.Verify(x => x.SendAsync(
It.IsNotNull<Message>(),
It.IsAny<CancellationToken>()), connectionTimes);
Plugin.Verify(x => x.Dispose(), Times.Once);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,10 @@ public async Task HandleResponseAsync_ReturnsNotFoundIfPackageSourceNotFound()
It.IsAny<CancellationToken>()))
.ReturnsAsync((ICredentials)null);

var plugin = CreateMockPlugin();

using (var provider = new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
plugin,
Mock.Of<IWebProxy>(),
credentialService.Object))
{
Expand Down Expand Up @@ -294,10 +296,9 @@ public async Task HandleResponseAsync_ReturnsPackageSourceCredentialsFromCredent
It.IsAny<CancellationToken>()))
.Returns(Task.FromResult(credentials.Object));

using (var provider = new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
proxy,
credentialService.Object))
var plugin = CreateMockPlugin();

using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object))
{
var request = CreateRequest(
MessageType.Request,
Expand All @@ -324,8 +325,10 @@ await provider.HandleResponseAsync(
[Fact]
public async Task HandleResponseAsync_ReturnsNullPackageSourceCredentialsIfPackageSourceCredentialsAreInvalidAndCredentialServiceIsNull()
{
var plugin = CreateMockPlugin();

using (var provider = new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
plugin,
Mock.Of<IWebProxy>(),
credentialService: null))
{
Expand Down Expand Up @@ -371,10 +374,9 @@ public async Task HandleResponseAsync_ReturnsNullPackageSourceCredentialsIfNoCre
It.IsAny<CancellationToken>()))
.Returns(Task.FromResult(credentials.Object));

using (var provider = new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
proxy,
credentialService.Object))
var plugin = CreateMockPlugin();

using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object))
{
var request = CreateRequest(
MessageType.Request,
Expand Down Expand Up @@ -423,10 +425,9 @@ public async Task HandleResponseAsync_ReturnsProxyCredentialsFromCredentialServi
It.IsAny<CancellationToken>()))
.Returns(Task.FromResult(credentials.Object));

using (var provider = new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
proxy.Object,
credentialService.Object))
var plugin = CreateMockPlugin();

using (var provider = new GetCredentialsRequestHandler(plugin, proxy.Object, credentialService.Object))
{
var request = CreateRequest(
MessageType.Request,
Expand All @@ -453,8 +454,10 @@ await provider.HandleResponseAsync(
[Fact]
public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfCredentialServiceIsNull()
{
var plugin = CreateMockPlugin();

using (var provider = new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
plugin,
Mock.Of<IWebProxy>(),
credentialService: null))
{
Expand Down Expand Up @@ -500,10 +503,9 @@ public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfNoCredentials
It.IsAny<CancellationToken>()))
.Returns(Task.FromResult(credentials.Object));

using (var provider = new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
proxy,
credentialService.Object))
var plugin = CreateMockPlugin();

using (var provider = new GetCredentialsRequestHandler(plugin, proxy, credentialService.Object))
{
var request = CreateRequest(
MessageType.Request,
Expand All @@ -530,8 +532,10 @@ await provider.HandleResponseAsync(
[Fact]
public async Task HandleResponseAsync_ReturnsNullProxyCredentialsIfNoProxy()
{
var plugin = CreateMockPlugin();

using (var provider = new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
plugin,
proxy: null,
credentialService: Mock.Of<ICredentialService>()))
{
Expand Down Expand Up @@ -559,10 +563,19 @@ await provider.HandleResponseAsync(

private GetCredentialsRequestHandler CreateDefaultRequestHandler()
{
return new GetCredentialsRequestHandler(
Mock.Of<IPlugin>(),
Mock.Of<IWebProxy>(),
Mock.Of<ICredentialService>());
var plugin = CreateMockPlugin();

return new GetCredentialsRequestHandler(plugin, Mock.Of<IWebProxy>(), Mock.Of<ICredentialService>());
}

private static IPlugin CreateMockPlugin()
{
var plugin = new Mock<IPlugin>();

plugin.SetupGet(x => x.Connection)
.Returns(Mock.Of<IConnection>());

return plugin.Object;
}

private static Message CreateRequest(MessageType type, GetCredentialsRequest payload = null)
Expand Down

0 comments on commit 8c10f33

Please sign in to comment.