Skip to content

Commit 4d442ba

Browse files
Support selector pattern simplification in linq queries (#76086)
2 parents 540bf4c + a3b12cd commit 4d442ba

File tree

3 files changed

+132
-80
lines changed

3 files changed

+132
-80
lines changed

src/Analyzers/CSharp/Analyzers/SimplifyLinqExpression/CSharpSimplifyLinqExpressionDiagnosticAnalyzer.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,5 @@ internal sealed class CSharpSimplifyLinqExpressionDiagnosticAnalyzer : AbstractS
2121
protected override IInvocationOperation? TryGetNextInvocationInChain(IInvocationOperation invocation)
2222
// In C#, extension methods contain the methods they are being called from in the `this` parameter
2323
// So in the case of A().ExtensionB() to get to ExtensionB from A we do the following:
24-
=> invocation.Parent is IArgumentOperation argument &&
25-
argument.Parent is IInvocationOperation nextInvocation
26-
? nextInvocation
27-
: null;
24+
=> invocation.Parent is IArgumentOperation { Parent: IInvocationOperation nextInvocation } ? nextInvocation : null;
2825
}

src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionTests.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,4 +640,38 @@ public void Test()
640640
"""
641641
}.RunAsync();
642642
}
643+
644+
[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/75845")]
645+
public static async Task TestSelectSum()
646+
{
647+
await new VerifyCS.Test
648+
{
649+
TestCode = """
650+
using System;
651+
using System.Linq;
652+
using System.Collections.Generic;
653+
654+
class C
655+
{
656+
public void Test(int[] numbers)
657+
{
658+
var sumOfSquares = [|numbers.Select(n => n * n).Sum()|];
659+
}
660+
}
661+
""",
662+
FixedCode = """
663+
using System;
664+
using System.Linq;
665+
using System.Collections.Generic;
666+
667+
class C
668+
{
669+
public void Test(int[] numbers)
670+
{
671+
var sumOfSquares = numbers.Sum(n => n * n);
672+
}
673+
}
674+
"""
675+
}.RunAsync();
676+
}
643677
}

src/Analyzers/Core/Analyzers/SimplifyLinqExpression/AbstractSimplifyLinqExpressionDiagnosticAnalyzer.cs

Lines changed: 97 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Immutable;
67
using System.Diagnostics.CodeAnalysis;
78
using System.Linq;
@@ -13,33 +14,38 @@
1314

1415
namespace Microsoft.CodeAnalysis.SimplifyLinqExpression;
1516

16-
internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer<TInvocationExpressionSyntax, TMemberAccessExpressionSyntax> : AbstractBuiltInCodeStyleDiagnosticAnalyzer
17+
internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer<TInvocationExpressionSyntax, TMemberAccessExpressionSyntax>()
18+
: AbstractBuiltInCodeStyleDiagnosticAnalyzer(
19+
IDEDiagnosticIds.SimplifyLinqExpressionDiagnosticId,
20+
EnforceOnBuildValues.SimplifyLinqExpression,
21+
option: null,
22+
title: new LocalizableResourceString(nameof(AnalyzersResources.Simplify_LINQ_expression), AnalyzersResources.ResourceManager, typeof(AnalyzersResources)))
1723
where TInvocationExpressionSyntax : SyntaxNode
1824
where TMemberAccessExpressionSyntax : SyntaxNode
1925
{
20-
private static readonly ImmutableHashSet<string> s_nonEnumerableReturningLinqMethodNames =
21-
ImmutableHashSet.Create(
26+
private static readonly ImmutableHashSet<string> s_nonEnumerableReturningLinqPredicateMethodNames =
27+
[
2228
nameof(Enumerable.First),
2329
nameof(Enumerable.Last),
2430
nameof(Enumerable.Single),
2531
nameof(Enumerable.Any),
2632
nameof(Enumerable.Count),
2733
nameof(Enumerable.SingleOrDefault),
2834
nameof(Enumerable.FirstOrDefault),
29-
nameof(Enumerable.LastOrDefault));
35+
nameof(Enumerable.LastOrDefault),
36+
];
37+
private static readonly ImmutableHashSet<string> s_nonEnumerableReturningLinqSelectorMethodNames =
38+
[
39+
nameof(Enumerable.Average),
40+
nameof(Enumerable.Sum),
41+
nameof(Enumerable.Min),
42+
nameof(Enumerable.Max),
43+
];
3044

3145
protected abstract ISyntaxFacts SyntaxFacts { get; }
3246

3347
protected abstract bool ConflictsWithMemberByNameOnly { get; }
3448

35-
public AbstractSimplifyLinqExpressionDiagnosticAnalyzer()
36-
: base(IDEDiagnosticIds.SimplifyLinqExpressionDiagnosticId,
37-
EnforceOnBuildValues.SimplifyLinqExpression,
38-
option: null,
39-
title: new LocalizableResourceString(nameof(AnalyzersResources.Simplify_LINQ_expression), AnalyzersResources.ResourceManager, typeof(AnalyzersResources)))
40-
{
41-
}
42-
4349
protected abstract IInvocationOperation? TryGetNextInvocationInChain(IInvocationOperation invocation);
4450

4551
public override DiagnosticAnalyzerCategory GetAnalyzerCategory()
@@ -50,18 +56,13 @@ protected override void InitializeWorker(AnalysisContext context)
5056

5157
private void OnCompilationStart(CompilationStartAnalysisContext context)
5258
{
53-
if (!TryGetEnumerableTypeSymbol(context.Compilation, out var enumerableType))
54-
return;
55-
56-
if (!TryGetLinqWhereExtensionMethod(enumerableType, out var whereMethodSymbol))
57-
return;
58-
59-
if (!TryGetLinqMethodsThatDoNotReturnEnumerables(enumerableType, out var linqMethodSymbols))
60-
return;
61-
62-
context.RegisterOperationAction(
63-
context => AnalyzeInvocationOperation(context, enumerableType, whereMethodSymbol, linqMethodSymbols),
64-
OperationKind.Invocation);
59+
if (TryGetEnumerableTypeSymbol(context.Compilation, out var enumerableType) &&
60+
TryGetLinqWhereExtensionMethod(enumerableType, out var whereMethodSymbol) &&
61+
TryGetLinqSelectExtensionMethod(enumerableType, out var selectMethodSymbol) &&
62+
TryGetLinqMethodsThatDoNotReturnEnumerables(enumerableType, out var linqMethods))
63+
{
64+
context.RegisterOperationAction(AnalyzeInvocationOperation, OperationKind.Invocation);
65+
}
6566

6667
return;
6768

@@ -71,31 +72,40 @@ static bool TryGetEnumerableTypeSymbol(Compilation compilation, [NotNullWhen(tru
7172
return enumerableType is not null;
7273
}
7374

74-
static bool TryGetLinqWhereExtensionMethod(INamedTypeSymbol enumerableType, [NotNullWhen(true)] out IMethodSymbol? whereMethod)
75+
static bool TryGetLinqWhereExtensionMethod(INamedTypeSymbol enumerableType, [NotNullWhen(true)] out IMethodSymbol? linqMethod)
76+
=> TryGetLinqExtensionMethod(enumerableType, nameof(Enumerable.Where), out linqMethod);
77+
78+
static bool TryGetLinqSelectExtensionMethod(INamedTypeSymbol enumerableType, [NotNullWhen(true)] out IMethodSymbol? linqMethod)
79+
=> TryGetLinqExtensionMethod(enumerableType, nameof(Enumerable.Select), out linqMethod);
80+
81+
static bool TryGetLinqExtensionMethod(INamedTypeSymbol enumerableType, string name, [NotNullWhen(true)] out IMethodSymbol? linqMethod)
7582
{
76-
foreach (var whereMethodSymbol in enumerableType.GetMembers(nameof(Enumerable.Where)).OfType<IMethodSymbol>())
83+
foreach (var linqMethodSymbol in enumerableType.GetMembers(name).OfType<IMethodSymbol>())
7784
{
78-
var parameters = whereMethodSymbol.Parameters;
79-
80-
if (parameters is [_, { Type: INamedTypeSymbol { Arity: 2 } }])
85+
if (linqMethodSymbol.Parameters is [_, { Type: INamedTypeSymbol { Arity: 2 } }])
8186
{
82-
// This is the where overload that does not take and index (i.e. Where(source, Func<T, bool>) vs Where(source, Func<T, int, bool>))
83-
whereMethod = whereMethodSymbol;
87+
// This is the Where/Select overload that does not take and index (i.e. Where(source, Func<T, bool>)
88+
// vs Where(source, Func<T, int, bool>))
89+
linqMethod = linqMethodSymbol;
8490
return true;
8591
}
8692
}
8793

88-
whereMethod = null;
94+
linqMethod = null;
8995
return false;
9096
}
9197

9298
static bool TryGetLinqMethodsThatDoNotReturnEnumerables(INamedTypeSymbol enumerableType, out ImmutableArray<IMethodSymbol> linqMethods)
9399
{
94100
using var _ = ArrayBuilder<IMethodSymbol>.GetInstance(out var linqMethodSymbolsBuilder);
101+
95102
foreach (var method in enumerableType.GetMembers().OfType<IMethodSymbol>())
96103
{
97-
if (s_nonEnumerableReturningLinqMethodNames.Contains(method.Name) &&
98-
method.Parameters is { Length: 1 })
104+
if (method.Parameters.Length != 1)
105+
continue;
106+
107+
if (s_nonEnumerableReturningLinqPredicateMethodNames.Contains(method.Name) ||
108+
s_nonEnumerableReturningLinqSelectorMethodNames.Contains(method.Name))
99109
{
100110
linqMethodSymbolsBuilder.AddRange(method);
101111
}
@@ -104,65 +114,76 @@ static bool TryGetLinqMethodsThatDoNotReturnEnumerables(INamedTypeSymbol enumera
104114
linqMethods = linqMethodSymbolsBuilder.ToImmutable();
105115
return linqMethods.Any();
106116
}
107-
}
108-
109-
public void AnalyzeInvocationOperation(OperationAnalysisContext context, INamedTypeSymbol enumerableType, IMethodSymbol whereMethod, ImmutableArray<IMethodSymbol> linqMethods)
110-
{
111-
if (ShouldSkipAnalysis(context, notification: null))
112-
return;
113117

114-
if (context.Operation.Syntax.GetDiagnostics().Any(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error))
118+
void AnalyzeInvocationOperation(OperationAnalysisContext context)
115119
{
120+
if (ShouldSkipAnalysis(context, notification: null))
121+
return;
122+
116123
// Do not analyze linq methods that contain diagnostics.
117-
return;
118-
}
124+
if (context.Operation.Syntax.GetDiagnostics().Any(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error))
125+
return;
119126

120-
if (context.Operation is not IInvocationOperation invocation ||
121-
!IsWhereLinqMethod(invocation))
122-
{
123-
// we only care about Where methods on linq expressions
124-
return;
125-
}
127+
// we only care about Where/Select invocation methods on linq expressions
126128

127-
if (TryGetNextInvocationInChain(invocation) is not IInvocationOperation nextInvocation ||
128-
!IsInvocationNonEnumerableReturningLinqMethod(nextInvocation))
129-
{
130-
// Invocation is not part of a chain of invocations (i.e. Where(x => x is not null).First())
131-
return;
132-
}
129+
if (context.Operation is not IInvocationOperation invocation)
130+
return;
133131

134-
if (TryGetSymbolOfMemberAccess(invocation) is not INamedTypeSymbol targetTypeSymbol ||
135-
TryGetMethodName(nextInvocation) is not string name)
136-
{
137-
return;
138-
}
132+
var isWhereMethod = IsWhereLinqMethod(invocation);
133+
var isSelectMethod = IsSelectLinqMethod(invocation);
134+
if (!isWhereMethod && !isSelectMethod)
135+
return;
139136

140-
// Do not offer to transpose if there is already a method on the collection named the same as the linq extension
141-
// method. This would cause us to call the instance method after the transformation, not the extension method.
142-
if (!targetTypeSymbol.Equals(enumerableType, SymbolEqualityComparer.Default) &&
143-
targetTypeSymbol.MemberNames.Contains(name))
144-
{
145-
// VB conflicts if any member has the same name (like a Count property vs Count extension method).
146-
if (this.ConflictsWithMemberByNameOnly)
137+
if (TryGetNextInvocationInChain(invocation) is not IInvocationOperation nextInvocation ||
138+
!IsInvocationNonEnumerableReturningLinqMethod(nextInvocation))
139+
{
140+
// Invocation is not part of a chain of invocations (i.e. Where(x => x is not null).First())
147141
return;
142+
}
148143

149-
// C# conflicts only if it is a method as well. So a Count property will not conflict with a Count
150-
// extension method.
151-
if (targetTypeSymbol.GetMembers(name).Any(m => m is IMethodSymbol))
144+
if (TryGetSymbolOfMemberAccess(invocation) is not ITypeSymbol targetTypeSymbol ||
145+
TryGetMethodName(nextInvocation) is not string name)
146+
{
152147
return;
153-
}
148+
}
154149

155-
context.ReportDiagnostic(Diagnostic.Create(Descriptor, nextInvocation.Syntax.GetLocation()));
150+
if (isWhereMethod && !s_nonEnumerableReturningLinqPredicateMethodNames.Contains(name))
151+
return;
156152

157-
return;
153+
if (isSelectMethod && !s_nonEnumerableReturningLinqSelectorMethodNames.Contains(name))
154+
return;
155+
156+
// Do not offer to transpose if there is already a method on the collection named the same as the linq extension
157+
// method. This would cause us to call the instance method after the transformation, not the extension method.
158+
if (!targetTypeSymbol.Equals(enumerableType, SymbolEqualityComparer.Default))
159+
{
160+
var members = targetTypeSymbol.GetMembers(name);
161+
if (members.Length > 0)
162+
{
163+
// VB conflicts if any member has the same name (like a Count property vs Count extension method).
164+
if (this.ConflictsWithMemberByNameOnly)
165+
return;
166+
167+
// C# conflicts only if it is a method as well. So a Count property will not conflict with a Count
168+
// extension method.
169+
if (members.Any(m => m is IMethodSymbol))
170+
return;
171+
}
172+
}
173+
174+
context.ReportDiagnostic(Diagnostic.Create(Descriptor, nextInvocation.Syntax.GetLocation()));
175+
}
158176

159177
bool IsWhereLinqMethod(IInvocationOperation invocation)
160-
=> whereMethod.Equals(invocation.TargetMethod.ReducedFrom ?? invocation.TargetMethod.OriginalDefinition, SymbolEqualityComparer.Default);
178+
=> whereMethodSymbol.Equals(invocation.TargetMethod.ReducedFrom ?? invocation.TargetMethod.OriginalDefinition, SymbolEqualityComparer.Default);
179+
180+
bool IsSelectLinqMethod(IInvocationOperation invocation)
181+
=> selectMethodSymbol.Equals(invocation.TargetMethod.ReducedFrom ?? invocation.TargetMethod.OriginalDefinition, SymbolEqualityComparer.Default);
161182

162183
bool IsInvocationNonEnumerableReturningLinqMethod(IInvocationOperation invocation)
163184
=> linqMethods.Any(static (m, invocation) => m.Equals(invocation.TargetMethod.ReducedFrom ?? invocation.TargetMethod.OriginalDefinition, SymbolEqualityComparer.Default), invocation);
164185

165-
INamedTypeSymbol? TryGetSymbolOfMemberAccess(IInvocationOperation invocation)
186+
ITypeSymbol? TryGetSymbolOfMemberAccess(IInvocationOperation invocation)
166187
{
167188
if (invocation.Syntax is not TInvocationExpressionSyntax invocationNode ||
168189
SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is not TMemberAccessExpressionSyntax memberAccess ||
@@ -171,7 +192,7 @@ bool IsInvocationNonEnumerableReturningLinqMethod(IInvocationOperation invocatio
171192
return null;
172193
}
173194

174-
return invocation.SemanticModel?.GetTypeInfo(expression).Type as INamedTypeSymbol;
195+
return invocation.SemanticModel?.GetTypeInfo(expression).Type;
175196
}
176197

177198
string? TryGetMethodName(IInvocationOperation invocation)

0 commit comments

Comments
 (0)