diff --git a/src/Microsoft.TestPlatform.CommunicationUtilities/TestRequestSender.cs b/src/Microsoft.TestPlatform.CommunicationUtilities/TestRequestSender.cs index ed724bc1fe..ac039c0776 100644 --- a/src/Microsoft.TestPlatform.CommunicationUtilities/TestRequestSender.cs +++ b/src/Microsoft.TestPlatform.CommunicationUtilities/TestRequestSender.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Globalization; using System.Threading; +using System.Threading.Tasks; using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities.Interfaces; using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities.ObjectModel; @@ -40,6 +41,7 @@ public class TestRequestSender : ITestRequestSender private ICommunicationChannel? _channel; private EventHandler? _onMessageReceived; + private DisconnectedEventArgs? _disconnectedInfo; private Action? _onDisconnected; // Set to 1 if Discovery/Execution is complete, i.e. complete handlers have been invoked private int _operationCompleted; @@ -151,8 +153,14 @@ public int InitializeCommunication() }; _communicationEndpoint.Disconnected += (sender, args) => + { + // Store the disconnected info, so that any further DiscoverTests, + // RunTests methods can immediately bail. + _disconnectedInfo = args; + // If there's an disconnected event handler, call it - _onDisconnected?.Invoke(args); + InvokeDisconnectedHandler(args); + }; // Server start returns the listener port // return int.Parse(this.communicationServer.Start()); @@ -161,6 +169,46 @@ public int InitializeCommunication() return endpoint.GetIpEndPoint().Port; } + + private bool TrySetupMessageReceiver( + EventHandler onMessageReceived, + Action onDisconnected) + { + TPDebug.Assert(_channel is not null, "_channel is null"); + + // Note: Attempts to setup a message receiver. + // It's possible that the testhost was already disconnected and in that case we should + // immediately call the disconnected callback. + + // Design: The current method is needed because the request sender sets up + // the disconnect handler late. If the first thing that is done by the class + // is to setup the disconnect handler, then we'd only need to fire the handler + // when the disconnect event fires. + + _onDisconnected = onDisconnected; + + // If the testhost was already disconnected, trigger the handler immediately. + if (_disconnectedInfo is DisconnectedEventArgs args) + { + InvokeDisconnectedHandler(args); + return false; + } + + _onMessageReceived = onMessageReceived; + _channel.MessageReceived += _onMessageReceived; + + return true; + } + + private void InvokeDisconnectedHandler(DisconnectedEventArgs args) + { + // Note: If the endpoint is disconnected at the same time as the + // disconnected handler is setup, it's possible for this method + // to be invoked twice. Ensure that the handler ever gets invoked once. + var handler = Interlocked.Exchange(ref _onDisconnected, null); + handler?.Invoke(args); + } + /// public bool WaitForRequestHandlerConnection(int connectionTimeout, CancellationToken cancellationToken) { @@ -189,7 +237,8 @@ public void CheckVersionWithTestHost() // Test host sends back the lower number of the two. So the highest protocol version, that both sides support is used. // Error case: test host can send a protocol error if it cannot find a supported version var protocolNegotiated = new ManualResetEvent(false); - _onMessageReceived = (sender, args) => + + EventHandler onMessageReceived = (sender, args) => { var message = _dataSerializer.DeserializeMessage(args.Data!); @@ -221,7 +270,7 @@ public void CheckVersionWithTestHost() protocolNegotiated.Set(); }; - _channel.MessageReceived += _onMessageReceived; + _channel.MessageReceived += onMessageReceived; try { @@ -242,8 +291,7 @@ public void CheckVersionWithTestHost() } finally { - _channel.MessageReceived -= _onMessageReceived; - _onMessageReceived = null; + _channel.MessageReceived -= onMessageReceived; } } @@ -270,10 +318,13 @@ public void DiscoverTests(DiscoveryCriteria discoveryCriteria, ITestDiscoveryEve _messageEventHandler = discoveryEventsHandler; // When testhost disconnects, it normally means there was an error in the testhost and it exited unexpectedly. // But when it was us who aborted the run and killed the testhost, we don't want to wait for it to report error, because there won't be any. - _onDisconnected = disconnectedEventArgs => OnDiscoveryAbort(discoveryEventsHandler, disconnectedEventArgs.Error, getClientError: !_isDiscoveryAborted); - _onMessageReceived = (sender, args) => OnDiscoveryMessageReceived(discoveryEventsHandler, args); + if (!TrySetupMessageReceiver( + onMessageReceived: (_, args) => OnDiscoveryMessageReceived(discoveryEventsHandler, args), + onDisconnected: disconnectedEventArgs => OnDiscoveryAbort(discoveryEventsHandler, disconnectedEventArgs.Error, getClientError: !_isDiscoveryAborted))) + { + return; + } - _channel.MessageReceived += _onMessageReceived; var message = _dataSerializer.SerializePayload( MessageType.StartDiscovery, discoveryCriteria, @@ -320,10 +371,13 @@ public void StartTestRun(TestRunCriteriaWithSources runCriteria, IInternalTestRu { TPDebug.Assert(_channel is not null, "_channel is null"); _messageEventHandler = eventHandler; - _onDisconnected = (disconnectedEventArgs) => OnTestRunAbort(eventHandler, disconnectedEventArgs.Error, true); - _onMessageReceived = (sender, args) => OnExecutionMessageReceived(args, eventHandler); - _channel.MessageReceived += _onMessageReceived; + if (!TrySetupMessageReceiver( + onMessageReceived: (_, args) => OnExecutionMessageReceived(args, eventHandler), + onDisconnected: disconnectedEventArgs => OnTestRunAbort(eventHandler, disconnectedEventArgs.Error, true))) + { + return; + } // This code section is needed because we altered the old testhost launch process for // the debugging workflow. Now we don't ask VS to launch and attach to the testhost @@ -360,10 +414,13 @@ public void StartTestRun(TestRunCriteriaWithTests runCriteria, IInternalTestRunE { TPDebug.Assert(_channel is not null, "_channel is null"); _messageEventHandler = eventHandler; - _onDisconnected = (disconnectedEventArgs) => OnTestRunAbort(eventHandler, disconnectedEventArgs.Error, true); - _onMessageReceived = (sender, args) => OnExecutionMessageReceived(args, eventHandler); - _channel.MessageReceived += _onMessageReceived; + if (!TrySetupMessageReceiver( + onMessageReceived: (_, args) => OnExecutionMessageReceived(args, eventHandler), + onDisconnected: disconnectedEventArgs => OnTestRunAbort(eventHandler, disconnectedEventArgs.Error, true))) + { + return; + } // This code section is needed because we altered the old testhost launch process for // the debugging workflow. Now we don't ask VS to launch and attach to the testhost diff --git a/test/Microsoft.TestPlatform.CommunicationUtilities.UnitTests/TestRequestSenderTests.cs b/test/Microsoft.TestPlatform.CommunicationUtilities.UnitTests/TestRequestSenderTests.cs index 892b49ed8c..6212a426c5 100644 --- a/test/Microsoft.TestPlatform.CommunicationUtilities.UnitTests/TestRequestSenderTests.cs +++ b/test/Microsoft.TestPlatform.CommunicationUtilities.UnitTests/TestRequestSenderTests.cs @@ -8,6 +8,7 @@ using System.Net; using System.Net.Sockets; using System.Threading; +using System.Threading.Tasks; using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities; using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities.Interfaces; @@ -470,6 +471,18 @@ public void DiscoverTestShouldNotifyLogMessageIfClientDisconnectedWithClientExit _mockDiscoveryEventsHandler.Verify(eh => eh.HandleRawMessage(It.Is(s => !string.IsNullOrEmpty(s) && s.Equals("Serialized Stderr"))), Times.Once); } + [TestMethod] + public void DiscoverTestShouldNotifyDiscoveryCompleteIfClientDisconnectedBeforeDiscovery() + { + SetupFakeCommunicationChannel(); + + RaiseClientDisconnectedEvent(); + + _testRequestSender.DiscoverTests(new DiscoveryCriteria(), _mockDiscoveryEventsHandler.Object); + + _mockDiscoveryEventsHandler.Verify(eh => eh.HandleDiscoveryComplete(It.Is(dc => dc.IsAborted == true && dc.TotalCount == -1), null)); + } + [TestMethod] public void DiscoverTestShouldNotifyDiscoveryCompleteIfClientDisconnected() { @@ -746,6 +759,52 @@ public void StartTestRunShouldNotifyErrorLogMessageIfClientDisconnectedWithClien _mockExecutionEventsHandler.Verify(eh => eh.HandleLogMessage(TestMessageLevel.Error, It.Is(s => s.Contains(expectedErrorMessage))), Times.Once); } + [TestMethod] + public void StartTestRunShouldNotifyExecutionCompleteIfClientDisconnectedBeforeRun() + { + SetupOperationAbortedPayload(); + SetupFakeCommunicationChannel(); + + RaiseClientDisconnectedEvent(); + + _testRequestSender.StartTestRun(_testRunCriteriaWithSources, _mockExecutionEventsHandler.Object); + + _mockExecutionEventsHandler.Verify(eh => eh.HandleTestRunComplete(It.Is(t => t.IsAborted), null, null, null), Times.Once); + _mockExecutionEventsHandler.Verify(eh => eh.HandleRawMessage("SerializedAbortedPayload"), Times.Once); + } + + [TestMethod] + public void StartTestRunWithTestsShouldNotifyExecutionCompleteIfClientDisconnectedBeforeRun() + { + var runCriteria = new TestRunCriteriaWithTests(new TestCase[2], "runsettings", null, null!); + SetupOperationAbortedPayload(); + SetupFakeCommunicationChannel(); + + RaiseClientDisconnectedEvent(); + + _testRequestSender.StartTestRun(runCriteria, _mockExecutionEventsHandler.Object); + + _mockExecutionEventsHandler.Verify(eh => eh.HandleTestRunComplete(It.Is(t => t.IsAborted), null, null, null), Times.Once); + _mockExecutionEventsHandler.Verify(eh => eh.HandleRawMessage("SerializedAbortedPayload"), Times.Once); + } + + [TestMethod] + public async Task StartTestRunWithTestsShouldNotifyExecutionCompleteIfClientDisconnectedBeforeRunInAThreadSafeWay() + { + var runCriteria = new TestRunCriteriaWithTests(new TestCase[2], "runsettings", null, null!); + SetupOperationAbortedPayload(); + SetupFakeCommunicationChannel(); + + // Note: Even if the calls get invoked on separate threads, the request sender should send back the complete message just once. + var t1 = Task.Run(RaiseClientDisconnectedEvent); + var t2 = Task.Run(() => _testRequestSender.StartTestRun(runCriteria, _mockExecutionEventsHandler.Object)); + + await Task.WhenAll(t1, t2); + + _mockExecutionEventsHandler.Verify(eh => eh.HandleTestRunComplete(It.Is(t => t.IsAborted), null, null, null), Times.Once); + _mockExecutionEventsHandler.Verify(eh => eh.HandleRawMessage("SerializedAbortedPayload"), Times.Once); + } + [TestMethod] public void StartTestRunShouldNotifyExecutionCompleteIfClientDisconnected() {