Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Adapter/MSTest.TestAdapter/Execution/TestClassInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ TestResult DoRun()
{
if (logListener is not null)
{
FixtureMethodRunner.RunOnContext(ExecutionContext, () =>
ExecutionContextHelpers.RunOnContext(ExecutionContext, () =>
{
initializationLogs += logListener.GetAndClearStandardOutput();
initializationTrace += logListener.GetAndClearDebugTrace();
Expand Down Expand Up @@ -801,7 +801,7 @@ void DoRun()
{
if (logListener is not null)
{
FixtureMethodRunner.RunOnContext(ExecutionContext, () =>
ExecutionContextHelpers.RunOnContext(ExecutionContext, () =>
{
initializationLogs = logListener.GetAndClearStandardOutput();
initializationErrorLogs = logListener.GetAndClearStandardError();
Expand Down
66 changes: 51 additions & 15 deletions src/Adapter/MSTest.TestAdapter/Execution/TestMethodInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,18 @@ public TAttributeType[] GetAttributes<TAttributeType>(bool inherit)
where TAttributeType : Attribute
=> ReflectHelper.Instance.GetDerivedAttributes<TAttributeType>(TestMethod, inherit).ToArray();

/// <inheritdoc cref="InvokeAsync(object[])" />
public virtual TestResult Invoke(object?[]? arguments)
=> InvokeAsync(arguments).GetAwaiter().GetResult();

/// <summary>
/// Execute test method. Capture failures, handle async and return result.
/// </summary>
/// <param name="arguments">
/// Arguments to pass to test method. (E.g. For data driven).
/// </param>
/// <returns>Result of test method invocation.</returns>
public virtual TestResult Invoke(object?[]? arguments)
public virtual async Task<TestResult> InvokeAsync(object?[]? arguments)
{
Stopwatch watch = new();
TestResult? result = null;
Expand All @@ -155,13 +159,16 @@ public virtual TestResult Invoke(object?[]? arguments)

try
{
FixtureMethodRunner.RunOnContext(executionContext, () =>
ExecutionContextHelpers.RunOnContext(executionContext, () =>
{
ThreadSafeStringWriter.CleanState();
listener = new LogMessageListener(MSTestSettings.CurrentSettings.CaptureDebugTraces);
executionContext = ExecutionContext.Capture();
});
result = IsTimeoutSet ? ExecuteInternalWithTimeout(arguments, executionContext) : ExecuteInternal(arguments, executionContext, null);

result = IsTimeoutSet
? await ExecuteInternalWithTimeoutAsync(arguments, executionContext)
: await ExecuteInternalAsync(arguments, executionContext, null);
}
finally
{
Expand All @@ -173,7 +180,7 @@ public virtual TestResult Invoke(object?[]? arguments)
result.Duration = watch.Elapsed;
if (listener is not null)
{
FixtureMethodRunner.RunOnContext(executionContext, () =>
ExecutionContextHelpers.RunOnContext(executionContext, () =>
{
result.DebugTrace = listener.GetAndClearDebugTrace();
result.LogOutput = listener.GetAndClearStandardOutput();
Expand Down Expand Up @@ -391,14 +398,14 @@ private void ThrowMultipleAttributesException(string attributeName)
/// <param name="timeoutTokenSource">The timeout token source.</param>
/// <returns>The result of the execution.</returns>
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Requirement is to handle all kinds of user exceptions and message appropriately.")]
private TestResult ExecuteInternal(object?[]? arguments, ExecutionContext? executionContext, CancellationTokenSource? timeoutTokenSource)
private async Task<TestResult> ExecuteInternalAsync(object?[]? arguments, ExecutionContext? executionContext, CancellationTokenSource? timeoutTokenSource)
{
DebugEx.Assert(TestMethod != null, "UnitTestExecuter.DefaultTestMethodInvoke: testMethod = null.");

var result = new TestResult();

// TODO remove dry violation with TestMethodRunner
FixtureMethodRunner.RunOnContext(executionContext, () =>
ExecutionContextHelpers.RunOnContext(executionContext, () =>
{
_classInstance = CreateTestClassInstance(result);
executionContext = ExecutionContext.Capture();
Expand All @@ -420,11 +427,33 @@ private TestResult ExecuteInternal(object?[]? arguments, ExecutionContext? execu
if (RunTestInitializeMethod(_classInstance, result, ref executionContext, timeoutTokenSource))
{
hasTestInitializePassed = true;
FixtureMethodRunner.RunOnContext(executionContext, () =>
var tcs = new TaskCompletionSource<object?>();
#pragma warning disable VSTHRD101 // Avoid unsupported async delegates
ExecutionContextHelpers.RunOnContext(executionContext, async () =>
{
TestMethod.InvokeAsSynchronousTask(_classInstance, arguments);
executionContext = ExecutionContext.Capture();
try
{
object? invokeResult = TestMethod.GetInvokeResult(_classInstance, arguments);
if (invokeResult is Task task)
{
await task;
}
else if (invokeResult is ValueTask valueTask)
{
await valueTask;
}

executionContext = ExecutionContext.Capture();
tcs.SetResult(null);
}
catch (Exception ex)
{
tcs.SetException(ex);
}
});
#pragma warning restore VSTHRD101 // Avoid unsupported async delegates

await tcs.Task;

result.Outcome = UTF.UnitTestOutcome.Passed;
}
Expand Down Expand Up @@ -694,7 +723,7 @@ private void RunTestCleanupMethod(TestResult result, ExecutionContext? execution
if (_classInstance is IAsyncDisposable classInstanceAsAsyncDisposable)
{
// If you implement IAsyncDisposable without calling the DisposeAsync this would result a resource leak.
FixtureMethodRunner.RunOnContext(executionContext, () =>
ExecutionContextHelpers.RunOnContext(executionContext, () =>
{
classInstanceAsAsyncDisposable.DisposeAsync().AsTask().Wait();
executionContext = ExecutionContext.Capture();
Expand All @@ -703,7 +732,7 @@ private void RunTestCleanupMethod(TestResult result, ExecutionContext? execution
#endif
if (_classInstance is IDisposable classInstanceAsDisposable)
{
FixtureMethodRunner.RunOnContext(executionContext, () =>
ExecutionContextHelpers.RunOnContext(executionContext, () =>
{
classInstanceAsDisposable.Dispose();
executionContext = ExecutionContext.Capture();
Expand Down Expand Up @@ -1058,7 +1087,7 @@ private bool SetTestContext(object classInstance, TestResult result)
/// <param name="executionContext">The execution context to execute the test method on.</param>
/// <returns>The result of execution.</returns>
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Requirement is to handle all kinds of user exceptions and message appropriately.")]
private TestResult ExecuteInternalWithTimeout(object?[]? arguments, ExecutionContext? executionContext)
private async Task<TestResult> ExecuteInternalWithTimeoutAsync(object?[]? arguments, ExecutionContext? executionContext)
{
DebugEx.Assert(IsTimeoutSet, "Timeout should be set");

Expand All @@ -1082,7 +1111,7 @@ private TestResult ExecuteInternalWithTimeout(object?[]? arguments, ExecutionCon

try
{
return ExecuteInternal(arguments, executionContext, timeoutTokenSource);
return await ExecuteInternalAsync(arguments, executionContext, timeoutTokenSource);
}
catch (OperationCanceledException)
{
Expand Down Expand Up @@ -1133,7 +1162,7 @@ private TestResult ExecuteInternalWithTimeout(object?[]? arguments, ExecutionCon
else
{
// Cancel the token source as test has timed out
TestContext.Context.CancellationTokenSource.Cancel();
await TestContext.Context.CancellationTokenSource.CancelAsync();
}

TestResult timeoutResult = new() { Outcome = UTF.UnitTestOutcome.Timeout, TestFailureException = new TestFailedException(UTFUnitTestOutcome.Timeout, errorMessage) };
Expand All @@ -1152,7 +1181,14 @@ void ExecuteAsyncAction()
{
try
{
result = ExecuteInternal(arguments, executionContext, null);
// TODO: Avoid blocking.
// This used to always happen, but now is moved to the code path where there is a Timeout on the test method.
// The GetAwaiter().GetResult() call here can be a source of deadlocks, especially for UWP/WinUI.
// When the test method has `await`s with ConfigureAwait(true) (which is the default), the continuation is
// dispatched back to the SynchronizationContext which offloads the work to the UI thread.
// However, the GetAwaiter().GetResult() here will block the current thread which is also the UI thread.
// So, the continuations will not be able, thus this task never completes.
result = ExecuteInternalAsync(arguments, executionContext, null).GetAwaiter().GetResult();
}
catch (Exception ex)
{
Expand Down
44 changes: 22 additions & 22 deletions src/Adapter/MSTest.TestAdapter/Execution/TestMethodRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public TestMethodRunner(TestMethodInfo testMethodInfo, TestMethod testMethod, IT
/// Executes a test.
/// </summary>
/// <returns>The test results.</returns>
internal TestResult[] Execute(string initializationLogs, string initializationErrorLogs, string initializationTrace, string initializationTestContextMessages)
internal async Task<TestResult[]> ExecuteAsync(string initializationLogs, string initializationErrorLogs, string initializationTrace, string initializationTestContextMessages)
{
bool isSTATestClass = AttributeComparer.IsDerived<STATestClassAttribute>(_testMethodInfo.Parent.ClassAttribute);
bool isSTATestMethod = AttributeComparer.IsDerived<STATestMethodAttribute>(_testMethodInfo.Executor);
Expand All @@ -70,7 +70,7 @@ internal TestResult[] Execute(string initializationLogs, string initializationEr
if (isSTARequested && isWindowsOS && Thread.CurrentThread.GetApartmentState() != ApartmentState.STA)
{
TestResult[]? results = null;
Thread entryPointThread = new(() => results = SafeRunTestMethod(initializationLogs, initializationErrorLogs, initializationTrace, initializationTestContextMessages))
Thread entryPointThread = new(() => results = SafeRunTestMethodAsync(initializationLogs, initializationErrorLogs, initializationTrace, initializationTestContextMessages).GetAwaiter().GetResult())
{
Name = (isSTATestClass, isSTATestMethod) switch
{
Expand Down Expand Up @@ -102,17 +102,17 @@ internal TestResult[] Execute(string initializationLogs, string initializationEr
PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning);
}

return SafeRunTestMethod(initializationLogs, initializationErrorLogs, initializationTrace, initializationTestContextMessages);
return await SafeRunTestMethodAsync(initializationLogs, initializationErrorLogs, initializationTrace, initializationTestContextMessages);
}

// Local functions
TestResult[] SafeRunTestMethod(string initializationLogs, string initializationErrorLogs, string initializationTrace, string initializationTestContextMessages)
async Task<TestResult[]> SafeRunTestMethodAsync(string initializationLogs, string initializationErrorLogs, string initializationTrace, string initializationTestContextMessages)
{
TestResult[]? result = null;

try
{
result = RunTestMethod();
result = await RunTestMethodAsync();
}
catch (TestFailedException ex)
{
Expand Down Expand Up @@ -155,7 +155,7 @@ TestResult[] SafeRunTestMethod(string initializationLogs, string initializationE
/// Runs the test method.
/// </summary>
/// <returns>The test results.</returns>
internal TestResult[] RunTestMethod()
internal async Task<TestResult[]> RunTestMethodAsync()
{
DebugEx.Assert(_test != null, "Test should not be null.");
DebugEx.Assert(_testMethodInfo.TestMethod != null, "Test method should not be null.");
Expand All @@ -177,21 +177,21 @@ internal TestResult[] RunTestMethod()
}

object?[]? data = DataSerializationHelper.Deserialize(_test.SerializedData);
TestResult[] testResults = ExecuteTestWithDataSource(null, data);
TestResult[] testResults = await ExecuteTestWithDataSourceAsync(null, data);
results.AddRange(testResults);
}
else if (TryExecuteDataSourceBasedTests(results))
else if (await TryExecuteDataSourceBasedTestsAsync(results))
{
isDataDriven = true;
}
else if (TryExecuteFoldedDataDrivenTests(results))
else if (await TryExecuteFoldedDataDrivenTestsAsync(results))
{
isDataDriven = true;
}
else
{
_testContext.SetDisplayName(_test.DisplayName);
TestResult[] testResults = ExecuteTest(_testMethodInfo);
TestResult[] testResults = await ExecuteTestAsync(_testMethodInfo);

foreach (TestResult testResult in testResults)
{
Expand Down Expand Up @@ -247,19 +247,19 @@ internal TestResult[] RunTestMethod()
return results.ToArray();
}

private bool TryExecuteDataSourceBasedTests(List<TestResult> results)
private async Task<bool> TryExecuteDataSourceBasedTestsAsync(List<TestResult> results)
{
DataSourceAttribute[] dataSourceAttribute = _testMethodInfo.GetAttributes<DataSourceAttribute>(false);
if (dataSourceAttribute is { Length: 1 })
{
ExecuteTestFromDataSourceAttribute(results);
await ExecuteTestFromDataSourceAttributeAsync(results);
return true;
}

return false;
}

private bool TryExecuteFoldedDataDrivenTests(List<TestResult> results)
private async Task<bool> TryExecuteFoldedDataDrivenTestsAsync(List<TestResult> results)
{
IEnumerable<UTF.ITestDataSource>? testDataSources = _testMethodInfo.GetAttributes<Attribute>(false)?.OfType<UTF.ITestDataSource>();
if (testDataSources?.Any() != true)
Expand Down Expand Up @@ -300,7 +300,7 @@ private bool TryExecuteFoldedDataDrivenTests(List<TestResult> results)
{
try
{
TestResult[] testResults = ExecuteTestWithDataSource(testDataSource, data);
TestResult[] testResults = await ExecuteTestWithDataSourceAsync(testDataSource, data);

results.AddRange(testResults);
}
Expand All @@ -314,7 +314,7 @@ private bool TryExecuteFoldedDataDrivenTests(List<TestResult> results)
return true;
}

private void ExecuteTestFromDataSourceAttribute(List<TestResult> results)
private async Task ExecuteTestFromDataSourceAttributeAsync(List<TestResult> results)
{
Stopwatch watch = new();
watch.Start();
Expand All @@ -339,7 +339,7 @@ private void ExecuteTestFromDataSourceAttribute(List<TestResult> results)

foreach (object dataRow in dataRows)
{
TestResult[] testResults = ExecuteTestWithDataRow(dataRow, rowIndex++);
TestResult[] testResults = await ExecuteTestWithDataRowAsync(dataRow, rowIndex++);
results.AddRange(testResults);
}
}
Expand All @@ -361,7 +361,7 @@ private void ExecuteTestFromDataSourceAttribute(List<TestResult> results)
}
}

private TestResult[] ExecuteTestWithDataSource(UTF.ITestDataSource? testDataSource, object?[]? data)
private async Task<TestResult[]> ExecuteTestWithDataSourceAsync(UTF.ITestDataSource? testDataSource, object?[]? data)
{
string? displayName = StringEx.IsNullOrWhiteSpace(_test.DisplayName)
? _test.Name
Expand Down Expand Up @@ -408,7 +408,7 @@ private TestResult[] ExecuteTestWithDataSource(UTF.ITestDataSource? testDataSour

TestResult[] testResults = ignoreFromTestDataRow is not null
? [TestResult.CreateIgnoredResult(ignoreFromTestDataRow)]
: ExecuteTest(_testMethodInfo);
: await ExecuteTestAsync(_testMethodInfo);

stopwatch.Stop();

Expand All @@ -425,7 +425,7 @@ private TestResult[] ExecuteTestWithDataSource(UTF.ITestDataSource? testDataSour
return testResults;
}

private TestResult[] ExecuteTestWithDataRow(object dataRow, int rowIndex)
private async Task<TestResult[]> ExecuteTestWithDataRowAsync(object dataRow, int rowIndex)
{
string displayName = string.Format(CultureInfo.CurrentCulture, Resource.DataDrivenResultDisplayName, _test.DisplayName, rowIndex);
Stopwatch? stopwatch = null;
Expand All @@ -435,7 +435,7 @@ private TestResult[] ExecuteTestWithDataRow(object dataRow, int rowIndex)
{
stopwatch = Stopwatch.StartNew();
_testContext.SetDataRow(dataRow);
testResults = ExecuteTest(_testMethodInfo);
testResults = await ExecuteTestAsync(_testMethodInfo);
}
finally
{
Expand All @@ -453,11 +453,11 @@ private TestResult[] ExecuteTestWithDataRow(object dataRow, int rowIndex)
return testResults;
}

private TestResult[] ExecuteTest(TestMethodInfo testMethodInfo)
private async Task<TestResult[]> ExecuteTestAsync(TestMethodInfo testMethodInfo)
{
try
{
return _testMethodInfo.Executor.Execute(testMethodInfo);
return await _testMethodInfo.Executor.ExecuteAsync(testMethodInfo);
}
catch (Exception ex)
{
Expand Down
Loading