Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
Expand Down Expand Up @@ -67,54 +68,63 @@ protected override bool IsAsyncReturnType(ITypeSymbol type, KnownTaskTypes known
=> IsIAsyncEnumerableOrEnumerator(type, knownTypes) ||
knownTypes.IsTaskLike(type);

protected override SyntaxNode AddAsyncTokenAndFixReturnType(
protected override SyntaxNode FixMethodSignature(
bool addAsyncModifier,
bool keepVoid,
IMethodSymbol methodSymbol,
SyntaxNode node,
KnownTaskTypes knownTypes,
CancellationToken cancellationToken)
KnownTaskTypes knownTypes)
{
// We currently fix signature without adding 'async' modifier
// only for a partial definitions part of partial methods
Debug.Assert(addAsyncModifier || node is MethodDeclarationSyntax);

return node switch
{
MethodDeclarationSyntax method => FixMethod(keepVoid, methodSymbol, method, knownTypes, cancellationToken),
LocalFunctionStatementSyntax localFunction => FixLocalFunction(keepVoid, methodSymbol, localFunction, knownTypes, cancellationToken),
MethodDeclarationSyntax method => FixMethod(addAsyncModifier, keepVoid, methodSymbol, method, knownTypes),
LocalFunctionStatementSyntax localFunction => FixLocalFunction(keepVoid, methodSymbol, localFunction, knownTypes),
AnonymousFunctionExpressionSyntax anonymous => FixAnonymousFunction(anonymous),
_ => node,
};
}

private static MethodDeclarationSyntax FixMethod(
bool addAsyncModifier,
bool keepVoid,
IMethodSymbol methodSymbol,
MethodDeclarationSyntax method,
KnownTaskTypes knownTypes,
CancellationToken cancellationToken)
KnownTaskTypes knownTypes)
{
var (newModifiers, newReturnType) = AddAsyncModifierWithCorrectedTrivia(
method.Modifiers,
FixMethodReturnType(keepVoid, methodSymbol, method.ReturnType, knownTypes, cancellationToken));
return method.WithReturnType(newReturnType).WithModifiers(newModifiers);
var fixedReturnType = FixMethodReturnType(keepVoid, methodSymbol, method.ReturnType, knownTypes);

if (addAsyncModifier)
{
var (newModifiers, newReturnType) = AddAsyncModifierWithCorrectedTrivia(method.Modifiers, fixedReturnType);
return method.WithReturnType(newReturnType).WithModifiers(newModifiers);
}
else
{
return method.WithReturnType(fixedReturnType);
}
}

private static LocalFunctionStatementSyntax FixLocalFunction(
bool keepVoid,
IMethodSymbol methodSymbol,
LocalFunctionStatementSyntax localFunction,
KnownTaskTypes knownTypes,
CancellationToken cancellationToken)
KnownTaskTypes knownTypes)
{
var (newModifiers, newReturnType) = AddAsyncModifierWithCorrectedTrivia(
localFunction.Modifiers,
FixMethodReturnType(keepVoid, methodSymbol, localFunction.ReturnType, knownTypes, cancellationToken));
FixMethodReturnType(keepVoid, methodSymbol, localFunction.ReturnType, knownTypes));
return localFunction.WithReturnType(newReturnType).WithModifiers(newModifiers);
}

private static TypeSyntax FixMethodReturnType(
bool keepVoid,
IMethodSymbol methodSymbol,
TypeSyntax returnTypeSyntax,
KnownTaskTypes knownTypes,
CancellationToken cancellationToken)
KnownTaskTypes knownTypes)
{
var newReturnType = returnTypeSyntax.WithAdditionalAnnotations(Formatter.Annotation);

Expand All @@ -128,13 +138,13 @@ private static TypeSyntax FixMethodReturnType(
else
{
var returnType = methodSymbol.ReturnType;
if (IsIEnumerable(returnType, knownTypes) && IsIterator(methodSymbol, cancellationToken))
if (IsIEnumerable(returnType, knownTypes) && methodSymbol.IsIterator)
{
newReturnType = knownTypes.IAsyncEnumerableOfTType is null
? MakeGenericType(nameof(IAsyncEnumerable<>), methodSymbol.ReturnType)
: knownTypes.IAsyncEnumerableOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
}
else if (IsIEnumerator(returnType, knownTypes) && IsIterator(methodSymbol, cancellationToken))
else if (IsIEnumerator(returnType, knownTypes) && methodSymbol.IsIterator)
{
newReturnType = knownTypes.IAsyncEnumeratorOfTType is null
? MakeGenericType(nameof(IAsyncEnumerator<>), methodSymbol.ReturnType)
Expand Down Expand Up @@ -164,9 +174,6 @@ static TypeSyntax MakeGenericType(string type, ITypeSymbol typeArgumentFrom)
}
}

private static bool IsIterator(IMethodSymbol method, CancellationToken cancellationToken)
=> method.Locations.Any(static (loc, cancellationToken) => loc.FindNode(cancellationToken).ContainsYield(), cancellationToken);

private static bool IsIAsyncEnumerableOrEnumerator(ITypeSymbol returnType, KnownTaskTypes knownTypes)
=> returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTType) ||
returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ partial void M()

public partial class C
{
partial void MAsync();
partial Task MAsync();
}

public partial class C
Expand Down Expand Up @@ -1440,7 +1440,7 @@ public partial void M()

public partial class C
{
public partial void MAsync();
public partial Task MAsync();
}

public partial class C
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Rename;
using Microsoft.CodeAnalysis.Shared.Extensions;
Expand All @@ -24,8 +25,12 @@ internal abstract partial class AbstractMakeMethodAsynchronousCodeFixProvider :

protected abstract bool IsAsyncReturnType(ITypeSymbol type, KnownTaskTypes knownTypes);

protected abstract SyntaxNode AddAsyncTokenAndFixReturnType(
bool keepVoid, IMethodSymbol methodSymbol, SyntaxNode node, KnownTaskTypes knownTypes, CancellationToken cancellationToken);
protected abstract SyntaxNode FixMethodSignature(
bool addAsyncModifier,
bool keepVoid,
IMethodSymbol methodSymbol,
SyntaxNode node,
KnownTaskTypes knownTypes);

public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;

Expand Down Expand Up @@ -119,7 +124,7 @@ private async Task<Solution> FixNodeAsync(

return NeedsRename()
? await RenameThenAddAsyncTokenAsync(keepVoid, document, node, methodSymbol, knownTypes, cancellationToken).ConfigureAwait(false)
: await AddAsyncTokenAsync(keepVoid, document, methodSymbol, knownTypes, node, cancellationToken).ConfigureAwait(false);
: await FixRelatedSignaturesAsync(keepVoid, document, methodSymbol, knownTypes, node, cancellationToken).ConfigureAwait(false);

bool NeedsRename()
{
Expand Down Expand Up @@ -174,26 +179,39 @@ private async Task<Solution> RenameThenAddAsyncTokenAsync(
{
var semanticModel = await newDocument.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var newMethod = (IMethodSymbol)semanticModel.GetRequiredDeclaredSymbol(newNode, cancellationToken);
return await AddAsyncTokenAsync(keepVoid, newDocument, newMethod, knownTypes, newNode, cancellationToken).ConfigureAwait(false);
return await FixRelatedSignaturesAsync(keepVoid, newDocument, newMethod, knownTypes, newNode, cancellationToken).ConfigureAwait(false);
}

return newSolution;
}

private async Task<Solution> AddAsyncTokenAsync(
private async Task<Solution> FixRelatedSignaturesAsync(
bool keepVoid,
Document document,
IMethodSymbol methodSymbol,
KnownTaskTypes knownTypes,
SyntaxNode node,
CancellationToken cancellationToken)
{
var newNode = AddAsyncTokenAndFixReturnType(keepVoid, methodSymbol, node, knownTypes, cancellationToken);
var newNode = FixMethodSignature(addAsyncModifier: true, keepVoid, methodSymbol, node, knownTypes);

var solution = document.Project.Solution;
var solutionEditor = new SolutionEditor(solution);
var mainDocumentEditor = await solutionEditor.GetDocumentEditorAsync(document.Id, cancellationToken).ConfigureAwait(false);

var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var newRoot = root.ReplaceNode(node, newNode);
mainDocumentEditor.ReplaceNode(node, newNode);

if (!keepVoid && methodSymbol.PartialDefinitionPart is { Locations: [{ } partialDefinitionLocation] })
{
var partialDefinitionNode = partialDefinitionLocation.FindNode(cancellationToken);
var fixedPartialDefinitionNode = FixMethodSignature(addAsyncModifier: false, keepVoid, methodSymbol, partialDefinitionNode, knownTypes);

var partialDefinitionDocument = solution.GetDocument(partialDefinitionNode.SyntaxTree);
Contract.ThrowIfNull(partialDefinitionDocument);
var partialDefinitionDocumentEditor = await solutionEditor.GetDocumentEditorAsync(partialDefinitionDocument.Id, cancellationToken).ConfigureAwait(false);
partialDefinitionDocumentEditor.ReplaceNode(partialDefinitionNode, fixedPartialDefinitionNode);
}

var newDocument = document.WithSyntaxRoot(newRoot);
return newDocument.Project.Solution;
return solutionEditor.GetChangedSolution();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,16 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.MakeMethodAsynchronous
Return knownTypes.IsTaskLike(type)
End Function

Protected Overrides Function AddAsyncTokenAndFixReturnType(
Protected Overrides Function FixMethodSignature(
addAsyncModifier As Boolean,
keepVoid As Boolean,
methodSymbolOpt As IMethodSymbol,
node As SyntaxNode,
knownTypes As KnownTaskTypes,
cancellationToken As CancellationToken) As SyntaxNode
knownTypes As KnownTaskTypes) As SyntaxNode

' This flag can only be false when updating partial definition method signature.
' Since partial methods cannot be async in VB, it cannot be false here
Debug.Assert(addAsyncModifier)

If node.IsKind(SyntaxKind.SingleLineSubLambdaExpression) OrElse
node.IsKind(SyntaxKind.SingleLineFunctionLambdaExpression) Then
Expand Down
Loading