Skip to content

Commit

Permalink
Merge pull request #1252 from AArnott/analyzerFixes
Browse files Browse the repository at this point in the history
Fix VSTHRD003 to allow for awaiting more Task properties
  • Loading branch information
AArnott authored Oct 25, 2023
2 parents 4b595ea + b561132 commit cf86d12
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ internal static ContainingFunctionData GetContainingFunction(CSharpSyntaxNode? s
return new ContainingFunctionData(simpleLambda, simpleLambda.AsyncKeyword != default(SyntaxToken), SyntaxFactory.ParameterList().AddParameters(simpleLambda.Parameter), simpleLambda.Body, simpleLambda.WithBody);
}

if (syntaxNode is LocalFunctionStatementSyntax localFunc)
{
return new ContainingFunctionData(localFunc, localFunc.Modifiers.Any(SyntaxKind.AsyncKeyword), localFunc.ParameterList, (CSharpSyntaxNode?)localFunc.ExpressionBody ?? localFunc.Body, localFunc.WithBody);
}

if (syntaxNode is AnonymousMethodExpressionSyntax anonymousMethod)
{
return new ContainingFunctionData(anonymousMethod, anonymousMethod.AsyncKeyword != default(SyntaxToken), anonymousMethod.ParameterList, anonymousMethod.Body, anonymousMethod.WithBody);
Expand Down Expand Up @@ -125,7 +130,7 @@ internal static bool IsOnLeftHandOfAssignment(SyntaxNode syntaxNode)
return false;
}

internal static bool IsAssignedWithin(SyntaxNode container, SemanticModel semanticModel, ISymbol variable, CancellationToken cancellationToken)
internal static IEnumerable<ExpressionSyntax> FindAssignedValuesWithin(SyntaxNode container, SemanticModel semanticModel, ISymbol variable, CancellationToken cancellationToken)
{
if (semanticModel is null)
{
Expand All @@ -139,22 +144,36 @@ internal static bool IsAssignedWithin(SyntaxNode container, SemanticModel semant

if (container is null)
{
return false;
yield break;
}

foreach (SyntaxNode? node in container.DescendantNodesAndSelf(n => !(n is AnonymousFunctionExpressionSyntax)))
foreach (SyntaxNode? node in container.DescendantNodesAndSelf(n => !(n is AnonymousFunctionExpressionSyntax or LocalFunctionStatementSyntax)))
{
cancellationToken.ThrowIfCancellationRequested();
if (node is AssignmentExpressionSyntax assignment)
{
ISymbol? assignedSymbol = semanticModel.GetSymbolInfo(assignment.Left, cancellationToken).Symbol;
if (variable.Equals(assignedSymbol, SymbolEqualityComparer.Default))
{
return true;
yield return assignment.Right;
}
}
}

return false;
if (node is LocalDeclarationStatementSyntax localDeclarationStatement)
{
foreach (VariableDeclaratorSyntax localDeclVar in localDeclarationStatement.Declaration.Variables)
{
if (localDeclVar.Initializer is not null)
{
ISymbol? assignedSymbol = semanticModel.GetDeclaredSymbol(localDeclVar, cancellationToken);
if (variable.Equals(assignedSymbol, SymbolEqualityComparer.Default))
{
yield return localDeclVar.Initializer.Value;
}
}
}
}
}
}

internal static MemberAccessExpressionSyntax MemberAccess(IReadOnlyList<string> qualifiers, SimpleNameSyntax simpleName)
Expand Down Expand Up @@ -277,7 +296,7 @@ internal override bool IsAsyncMethod(SyntaxNode syntaxNode)

internal readonly struct ContainingFunctionData
{
internal ContainingFunctionData(CSharpSyntaxNode function, bool isAsync, ParameterListSyntax? parameterList, CSharpSyntaxNode? blockOrExpression, Func<CSharpSyntaxNode, CSharpSyntaxNode> bodyReplacement)
internal ContainingFunctionData(CSharpSyntaxNode function, bool isAsync, ParameterListSyntax? parameterList, CSharpSyntaxNode? blockOrExpression, Func<BlockSyntax, CSharpSyntaxNode> bodyReplacement)
{
this.Function = function;
this.IsAsync = isAsync;
Expand All @@ -294,6 +313,6 @@ internal ContainingFunctionData(CSharpSyntaxNode function, bool isAsync, Paramet

internal CSharpSyntaxNode? BlockOrExpression { get; }

internal Func<CSharpSyntaxNode, CSharpSyntaxNode> BodyReplacement { get; }
internal Func<BlockSyntax, CSharpSyntaxNode> BodyReplacement { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,25 +155,28 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
return null;
}

SymbolInfo symbolToConsider = semanticModel.GetSymbolInfo(expressionSyntax, cancellationToken);
ExpressionSyntax focusedExpression = expressionSyntax;
SymbolInfo symbolToConsider = semanticModel.GetSymbolInfo(focusedExpression, cancellationToken);
if (CommonInterest.TaskConfigureAwait.Any(configureAwait => configureAwait.IsMatch(symbolToConsider.Symbol)))
{
// If the invocation is wrapped inside parentheses then drill down to get the invocation.
while (expressionSyntax is ParenthesizedExpressionSyntax parenthesizedExprSyntax)
while (focusedExpression is ParenthesizedExpressionSyntax parenthesizedExprSyntax)
{
expressionSyntax = parenthesizedExprSyntax.Expression;
focusedExpression = parenthesizedExprSyntax.Expression;
}

Debug.Assert(expressionSyntax is InvocationExpressionSyntax, "expressionSyntax should be an invocation");
Debug.Assert(focusedExpression is InvocationExpressionSyntax, "focusedExpression should be an invocation");

if (((InvocationExpressionSyntax)expressionSyntax).Expression is MemberAccessExpressionSyntax memberAccessExpression)
if (((InvocationExpressionSyntax)focusedExpression).Expression is MemberAccessExpressionSyntax memberAccessExpression)
{
focusedExpression = memberAccessExpression.Expression;
symbolToConsider = semanticModel.GetSymbolInfo(memberAccessExpression.Expression, cancellationToken);
}
}

ITypeSymbol symbolType;
bool dataflowAnalysisCompatibleVariable = false;
CSharpUtils.ContainingFunctionData? containingFunc = null;
switch (symbolToConsider.Symbol)
{
case ILocalSymbol localSymbol:
Expand All @@ -182,6 +185,29 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
break;
case IPropertySymbol propertySymbol when !IsSymbolAlwaysOkToAwait(propertySymbol):
symbolType = propertySymbol.Type;

if (focusedExpression is MemberAccessExpressionSyntax memberAccessExpression)
{
// Do not report a warning if the task is a member of an object that was returned from an invocation made in this method.
if (memberAccessExpression.Expression is InvocationExpressionSyntax)
{
return null;
}

// Do not report a warning if the task is a member of an object that was created in this method.
if (memberAccessExpression.Expression is IdentifierNameSyntax identifier &&
semanticModel.GetSymbolInfo(identifier, cancellationToken).Symbol is ILocalSymbol local)
{
// Search for assignments to the local and see if it was to a new object.
containingFunc ??= CSharpUtils.GetContainingFunction(focusedExpression);
if (containingFunc.Value.BlockOrExpression is not null &&
CSharpUtils.FindAssignedValuesWithin(containingFunc.Value.BlockOrExpression, semanticModel, local, cancellationToken).Any(v => v is ObjectCreationExpressionSyntax))
{
return null;
}
}
}

break;
case IParameterSymbol parameterSymbol:
symbolType = parameterSymbol.Type;
Expand Down Expand Up @@ -247,7 +273,7 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)

break;
case IMethodSymbol methodSymbol:
if (Utils.IsTask(methodSymbol.ReturnType) && expressionSyntax is InvocationExpressionSyntax invocationExpressionSyntax)
if (Utils.IsTask(methodSymbol.ReturnType) && focusedExpression is InvocationExpressionSyntax invocationExpressionSyntax)
{
// Consider all arguments
IEnumerable<ExpressionSyntax>? expressionsToConsider = invocationExpressionSyntax.ArgumentList.Arguments.Select(a => a.Expression);
Expand Down Expand Up @@ -275,8 +301,8 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
}

// Report warning if the task was not initialized within the current delegate or lambda expression
CSharpUtils.ContainingFunctionData containingFunc = CSharpUtils.GetContainingFunction(expressionSyntax);
if (containingFunc.BlockOrExpression is BlockSyntax delegateBlock)
containingFunc ??= CSharpUtils.GetContainingFunction(focusedExpression);
if (containingFunc.Value.BlockOrExpression is BlockSyntax delegateBlock)
{
if (dataflowAnalysisCompatibleVariable)
{
Expand All @@ -285,9 +311,9 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)

// When possible (await is direct child of the block and not a field), execute data flow analysis by passing first and last statement to capture only what happens before the await
// Check if the await is direct child of the code block (first parent is ExpressionStantement, second parent is the block itself)
if (delegateBlock.Equals(expressionSyntax.Parent?.Parent?.Parent))
if (delegateBlock.Equals(focusedExpression.Parent?.Parent?.Parent))
{
dataFlowAnalysis = semanticModel.AnalyzeDataFlow(delegateBlock.ChildNodes().First(), expressionSyntax.Parent.Parent);
dataFlowAnalysis = semanticModel.AnalyzeDataFlow(delegateBlock.ChildNodes().First(), focusedExpression.Parent.Parent);
}
else
{
Expand All @@ -297,22 +323,22 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)

if (dataFlowAnalysis?.WrittenInside.Contains(symbolToConsider.Symbol) is false)
{
return Diagnostic.Create(Descriptor, expressionSyntax.GetLocation());
return Diagnostic.Create(Descriptor, focusedExpression.GetLocation());
}
}
else
{
// Do the best we can searching for assignment statements.
if (!CSharpUtils.IsAssignedWithin(containingFunc.BlockOrExpression, semanticModel, symbolToConsider.Symbol, cancellationToken))
if (!CSharpUtils.FindAssignedValuesWithin(containingFunc.Value.BlockOrExpression, semanticModel, symbolToConsider.Symbol, cancellationToken).Any())
{
return Diagnostic.Create(Descriptor, expressionSyntax.GetLocation());
return Diagnostic.Create(Descriptor, focusedExpression.GetLocation());
}
}
}
else
{
// It's not a block, it's just a lambda expression, so the variable must be external.
return Diagnostic.Create(Descriptor, expressionSyntax.GetLocation());
return Diagnostic.Create(Descriptor, focusedExpression.GetLocation());
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,11 @@ class Tests
public async Task AwaitAndGetResult()
{{
await task.ConfigureAwait({(continueOnCapturedContext ? "true" : "false")});
await [|task|].ConfigureAwait({(continueOnCapturedContext ? "true" : "false")});
}}
}}
";
DiagnosticResult expected = this.CreateDiagnostic(10, 15, 21 + continueOnCapturedContext.ToString().Length);
await CSVerify.VerifyAnalyzerAsync(test, expected);
await CSVerify.VerifyAnalyzerAsync(test);
}

[Theory]
Expand All @@ -223,12 +222,11 @@ class Tests
public async Task<int> AwaitAndGetResult()
{{
return await task.ConfigureAwait({(continueOnCapturedContext ? "true" : "false")});
return await [|task|].ConfigureAwait({(continueOnCapturedContext ? "true" : "false")});
}}
}}
";
DiagnosticResult expected = this.CreateDiagnostic(10, 22, 21 + continueOnCapturedContext.ToString().Length);
await CSVerify.VerifyAnalyzerAsync(test, expected);
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
Expand All @@ -244,12 +242,11 @@ class Tests
public async Task AwaitAndGetResult()
{
await task.ConfigureAwaitRunInline();
await [|task|].ConfigureAwaitRunInline();
}
}
";
DiagnosticResult expected = this.CreateDiagnostic(11, 15, 30);
await CSVerify.VerifyAnalyzerAsync(test, expected);
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
Expand All @@ -265,12 +262,11 @@ class Tests
public async Task<int> AwaitAndGetResult()
{
return await task.ConfigureAwaitRunInline();
return await [|task|].ConfigureAwaitRunInline();
}
}
";
DiagnosticResult expected = this.CreateDiagnostic(11, 22, 30);
await CSVerify.VerifyAnalyzerAsync(test, expected);
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
Expand Down Expand Up @@ -1280,6 +1276,64 @@ async Task GetTask()
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
public async Task DoNotReportWarningWhenAwaitingTaskPropertyOfObjectCreatedInContext()
{
var test = @"
using System.Threading.Tasks;
class Tests
{
private Task MyTaskProperty { get; set; }
static async Task GetTask()
{
// our own property.
var obj = new Tests();
await obj.MyTaskProperty;
// local with initializer
var tcs = new TaskCompletionSource<int>();
await tcs.Task;
// Assign later
TaskCompletionSource<int> tcs2;
tcs2 = new TaskCompletionSource<int>();
await tcs2.Task;
// Assigned, but not to a newly created object.
TaskCompletionSource<int> tcs3 = tcs2;
await [|tcs3.Task|];
}
}
";
await CSVerify.VerifyAnalyzerAsync(test);
}

/// <summary>
/// This is important to allow folks to return jtf.RunAsync(...).Task from a method.
/// </summary>
[Fact]
public async Task DoNotReportWarningWhenAwaitingTaskPropertyOfObjectReturnedFromMethod()
{
var test = @"
using System.Threading.Tasks;
class Tests
{
private Task MyTaskProperty { get; set; }
static Tests NewTests() => new Tests();
static async Task GetTask()
{
await NewTests().MyTaskProperty;
}
}
";
await CSVerify.VerifyAnalyzerAsync(test);
}

[Fact]
public async Task ReportWarningWhenAwaitingTaskPropertyThatWasNotSetInContext()
{
Expand All @@ -1293,6 +1347,7 @@ class Tests
async Task GetTask()
{
await [|this.MyTaskProperty|];
await [|MyTaskProperty|];
}
}
";
Expand Down

0 comments on commit cf86d12

Please sign in to comment.