Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -833,4 +833,104 @@ void M()
else
await VerifyAbsenceAsync(code);
}

[Fact]
public Task TestEventHandlerMethod_DoesNotChangeVoidToTask()
=> VerifyKeywordAsync("""
using System;

class C
{
public event EventHandler MyEvent;

public C()
{
MyEvent += OnMyEvent;
}

private void OnMyEvent(object sender, EventArgs e)
{
$$
}
}
""");

[Fact]
public Task TestEventHandlerMethod_WithDifferentEventType()
=> VerifyKeywordAsync("""
using System;

delegate void CustomEventHandler(object sender, EventArgs e);

class C
{
public event CustomEventHandler MyEvent;

public C()
{
MyEvent += HandleMyEvent;
}

private void HandleMyEvent(object sender, EventArgs e)
{
$$
}
}
""");

[Fact]
public Task TestEventHandlerMethod_InDifferentMethod()
=> VerifyKeywordAsync("""
using System;

class C
{
public event EventHandler MyEvent;

public void RegisterHandler()
{
MyEvent += OnMyEvent;
}

private void OnMyEvent(object sender, EventArgs e)
{
$$
}
}
""");

[Fact]
public Task TestEventHandlerMethod_WithMinusEquals()
=> VerifyKeywordAsync("""
using System;

class C
{
public event EventHandler MyEvent;

public void UnregisterHandler()
{
MyEvent -= OnMyEvent;
}

private void OnMyEvent(object sender, EventArgs e)
{
$$
}
}
""");

[Fact]
public Task TestNonEventHandlerMethod_ChangesVoidToTask()
=> VerifyKeywordAsync("""
using System;

class C
{
private void RegularMethod()
{
$$
}
}
""");
}
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,133 @@ public class C
await
}
}
", state.GetDocumentText())
End Using
End Function

<WpfFact>
Public Async Function AwaitCompletionDoesNotChangeReturnType_ForEventHandlerMethod() As Task
Using state = TestStateFactory.CreateCSharpTestState(
<Document><![CDATA[
using System;

public class C
{
public event EventHandler MyEvent;

public C()
{
MyEvent += OnMyEvent;
}

private void OnMyEvent(object sender, EventArgs e)
{
$$
}
}
]]>
</Document>)
state.SendTypeChars("aw")
Await state.AssertSelectedCompletionItem(displayText:="await", isHardSelected:=True)

state.SendTab()
Assert.Equal("
using System;

public class C
{
public event EventHandler MyEvent;

public C()
{
MyEvent += OnMyEvent;
}

private async void OnMyEvent(object sender, EventArgs e)
{
await
}
}
", state.GetDocumentText())
End Using
End Function

<WpfFact>
Public Async Function AwaitCompletionDoesNotChangeReturnType_ForEventHandlerWithMinusEquals() As Task
Using state = TestStateFactory.CreateCSharpTestState(
<Document><![CDATA[
using System;

public class C
{
public event EventHandler MyEvent;

public void UnregisterHandler()
{
MyEvent -= OnMyEvent;
}

private void OnMyEvent(object sender, EventArgs e)
{
$$
}
}
]]>
</Document>)
state.SendTypeChars("aw")
Await state.AssertSelectedCompletionItem(displayText:="await", isHardSelected:=True)

state.SendTab()
Assert.Equal("
using System;

public class C
{
public event EventHandler MyEvent;

public void UnregisterHandler()
{
MyEvent -= OnMyEvent;
}

private async void OnMyEvent(object sender, EventArgs e)
{
await
}
}
", state.GetDocumentText())
End Using
End Function

<WpfFact>
Public Async Function AwaitCompletionChangesVoidToTask_ForNonEventHandlerMethod() As Task
Using state = TestStateFactory.CreateCSharpTestState(
<Document><![CDATA[
using System.Threading.Tasks;

public class C
{
private void RegularMethod()
{
$$
}
}
]]>
</Document>)
state.SendTypeChars("aw")
Await state.AssertSelectedCompletionItem(displayText:="await", isHardSelected:=True)

state.SendTab()
Assert.Equal("
using System.Threading.Tasks;

public class C
{
private async Task RegularMethod()
{
await
}
}
", state.GetDocumentText())
End Using
End Function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ protected override int GetAsyncKeywordInsertionPosition(SyntaxNode declaration)
{
// `void => Task`
if (existingReturnType is PredefinedTypeSyntax { Keyword: (kind: SyntaxKind.VoidKeyword) })
{
// Don't change void to Task if this method is used as an event handler
if (IsMethodUsedAsEventHandler(declaration, semanticModel, cancellationToken))
return null;

return nameof(Task);
}

// Don't change the return type if we don't understand it, or it already seems task-like.
var taskLikeTypes = new KnownTaskTypes(semanticModel.Compilation);
Expand All @@ -94,6 +100,55 @@ protected override int GetAsyncKeywordInsertionPosition(SyntaxNode declaration)
}
}

private static bool IsMethodUsedAsEventHandler(SyntaxNode declaration, SemanticModel semanticModel, CancellationToken cancellationToken)
{
// Only check for methods (not lambdas or anonymous methods)
if (declaration is not MethodDeclarationSyntax methodDeclaration)
return false;

// Get the method symbol
var methodSymbol = semanticModel.GetDeclaredSymbol(methodDeclaration, cancellationToken);
if (methodSymbol is not IMethodSymbol method)
return false;

// Get the containing type
var containingType = method.ContainingType;
if (containingType is null)
return false;

// Get the syntax root for the containing type
var syntaxTree = methodDeclaration.SyntaxTree;
var root = syntaxTree.GetRoot(cancellationToken);

// Find all type declarations that could contain event hookups
var typeDeclarations = root.DescendantNodesAndSelf()
.OfType<TypeDeclarationSyntax>()
.Where(t => semanticModel.GetDeclaredSymbol(t, cancellationToken)?.Equals(containingType, SymbolEqualityComparer.Default) == true);

foreach (var typeDecl in typeDeclarations)
{
// Look for assignment expressions that could be event hookups (event += Method or event -= Method)
var assignmentExpressions = typeDecl.DescendantNodes()
.OfType<AssignmentExpressionSyntax>()
.Where(a => a.Kind() is SyntaxKind.AddAssignmentExpression or SyntaxKind.SubtractAssignmentExpression);

foreach (var assignment in assignmentExpressions)
{
// Check if the left side is an event
var leftSymbol = semanticModel.GetSymbolInfo(assignment.Left, cancellationToken).Symbol;
if (leftSymbol is not IEventSymbol)
continue;

// Check if the right side references our method
var rightSymbol = semanticModel.GetSymbolInfo(assignment.Right, cancellationToken).Symbol;
if (rightSymbol?.Equals(method, SymbolEqualityComparer.Default) == true)
return true;
}
}

return false;
}

protected override SyntaxNode? GetAsyncSupportingDeclaration(SyntaxToken leftToken, int position)
{
// In a case like
Expand Down
Loading