22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
5+ using System ;
56using System . Collections . Immutable ;
67using System . Diagnostics . CodeAnalysis ;
78using System . Linq ;
1314
1415namespace Microsoft . CodeAnalysis . SimplifyLinqExpression ;
1516
16- internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer < TInvocationExpressionSyntax , TMemberAccessExpressionSyntax > : AbstractBuiltInCodeStyleDiagnosticAnalyzer
17+ internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer < TInvocationExpressionSyntax , TMemberAccessExpressionSyntax > ( )
18+ : AbstractBuiltInCodeStyleDiagnosticAnalyzer (
19+ IDEDiagnosticIds . SimplifyLinqExpressionDiagnosticId ,
20+ EnforceOnBuildValues . SimplifyLinqExpression ,
21+ option : null ,
22+ title : new LocalizableResourceString ( nameof ( AnalyzersResources . Simplify_LINQ_expression ) , AnalyzersResources . ResourceManager , typeof ( AnalyzersResources ) ) )
1723 where TInvocationExpressionSyntax : SyntaxNode
1824 where TMemberAccessExpressionSyntax : SyntaxNode
1925{
20- private static readonly ImmutableHashSet < string > s_nonEnumerableReturningLinqMethodNames =
21- ImmutableHashSet . Create (
26+ private static readonly ImmutableHashSet < string > s_nonEnumerableReturningLinqPredicateMethodNames =
27+ [
2228 nameof ( Enumerable . First ) ,
2329 nameof ( Enumerable . Last ) ,
2430 nameof ( Enumerable . Single ) ,
2531 nameof ( Enumerable . Any ) ,
2632 nameof ( Enumerable . Count ) ,
2733 nameof ( Enumerable . SingleOrDefault ) ,
2834 nameof ( Enumerable . FirstOrDefault ) ,
29- nameof ( Enumerable . LastOrDefault ) ) ;
35+ nameof ( Enumerable . LastOrDefault ) ,
36+ ] ;
37+ private static readonly ImmutableHashSet < string > s_nonEnumerableReturningLinqSelectorMethodNames =
38+ [
39+ nameof ( Enumerable . Average ) ,
40+ nameof ( Enumerable . Sum ) ,
41+ nameof ( Enumerable . Min ) ,
42+ nameof ( Enumerable . Max ) ,
43+ ] ;
3044
3145 protected abstract ISyntaxFacts SyntaxFacts { get ; }
3246
3347 protected abstract bool ConflictsWithMemberByNameOnly { get ; }
3448
35- public AbstractSimplifyLinqExpressionDiagnosticAnalyzer ( )
36- : base ( IDEDiagnosticIds . SimplifyLinqExpressionDiagnosticId ,
37- EnforceOnBuildValues . SimplifyLinqExpression ,
38- option : null ,
39- title : new LocalizableResourceString ( nameof ( AnalyzersResources . Simplify_LINQ_expression ) , AnalyzersResources . ResourceManager , typeof ( AnalyzersResources ) ) )
40- {
41- }
42-
4349 protected abstract IInvocationOperation ? TryGetNextInvocationInChain ( IInvocationOperation invocation ) ;
4450
4551 public override DiagnosticAnalyzerCategory GetAnalyzerCategory ( )
@@ -50,18 +56,13 @@ protected override void InitializeWorker(AnalysisContext context)
5056
5157 private void OnCompilationStart ( CompilationStartAnalysisContext context )
5258 {
53- if ( ! TryGetEnumerableTypeSymbol ( context . Compilation , out var enumerableType ) )
54- return ;
55-
56- if ( ! TryGetLinqWhereExtensionMethod ( enumerableType , out var whereMethodSymbol ) )
57- return ;
58-
59- if ( ! TryGetLinqMethodsThatDoNotReturnEnumerables ( enumerableType , out var linqMethodSymbols ) )
60- return ;
61-
62- context . RegisterOperationAction (
63- context => AnalyzeInvocationOperation ( context , enumerableType , whereMethodSymbol , linqMethodSymbols ) ,
64- OperationKind . Invocation ) ;
59+ if ( TryGetEnumerableTypeSymbol ( context . Compilation , out var enumerableType ) &&
60+ TryGetLinqWhereExtensionMethod ( enumerableType , out var whereMethodSymbol ) &&
61+ TryGetLinqSelectExtensionMethod ( enumerableType , out var selectMethodSymbol ) &&
62+ TryGetLinqMethodsThatDoNotReturnEnumerables ( enumerableType , out var linqMethods ) )
63+ {
64+ context . RegisterOperationAction ( AnalyzeInvocationOperation , OperationKind . Invocation ) ;
65+ }
6566
6667 return ;
6768
@@ -71,31 +72,40 @@ static bool TryGetEnumerableTypeSymbol(Compilation compilation, [NotNullWhen(tru
7172 return enumerableType is not null ;
7273 }
7374
74- static bool TryGetLinqWhereExtensionMethod ( INamedTypeSymbol enumerableType , [ NotNullWhen ( true ) ] out IMethodSymbol ? whereMethod )
75+ static bool TryGetLinqWhereExtensionMethod ( INamedTypeSymbol enumerableType , [ NotNullWhen ( true ) ] out IMethodSymbol ? linqMethod )
76+ => TryGetLinqExtensionMethod ( enumerableType , nameof ( Enumerable . Where ) , out linqMethod ) ;
77+
78+ static bool TryGetLinqSelectExtensionMethod ( INamedTypeSymbol enumerableType , [ NotNullWhen ( true ) ] out IMethodSymbol ? linqMethod )
79+ => TryGetLinqExtensionMethod ( enumerableType , nameof ( Enumerable . Select ) , out linqMethod ) ;
80+
81+ static bool TryGetLinqExtensionMethod ( INamedTypeSymbol enumerableType , string name , [ NotNullWhen ( true ) ] out IMethodSymbol ? linqMethod )
7582 {
76- foreach ( var whereMethodSymbol in enumerableType . GetMembers ( nameof ( Enumerable . Where ) ) . OfType < IMethodSymbol > ( ) )
83+ foreach ( var linqMethodSymbol in enumerableType . GetMembers ( name ) . OfType < IMethodSymbol > ( ) )
7784 {
78- var parameters = whereMethodSymbol . Parameters ;
79-
80- if ( parameters is [ _, { Type : INamedTypeSymbol { Arity : 2 } } ] )
85+ if ( linqMethodSymbol . Parameters is [ _, { Type : INamedTypeSymbol { Arity : 2 } } ] )
8186 {
82- // This is the where overload that does not take and index (i.e. Where(source, Func<T, bool>) vs Where(source, Func<T, int, bool>))
83- whereMethod = whereMethodSymbol ;
87+ // This is the Where/Select overload that does not take and index (i.e. Where(source, Func<T, bool>)
88+ // vs Where(source, Func<T, int, bool>))
89+ linqMethod = linqMethodSymbol ;
8490 return true ;
8591 }
8692 }
8793
88- whereMethod = null ;
94+ linqMethod = null ;
8995 return false ;
9096 }
9197
9298 static bool TryGetLinqMethodsThatDoNotReturnEnumerables ( INamedTypeSymbol enumerableType , out ImmutableArray < IMethodSymbol > linqMethods )
9399 {
94100 using var _ = ArrayBuilder < IMethodSymbol > . GetInstance ( out var linqMethodSymbolsBuilder ) ;
101+
95102 foreach ( var method in enumerableType . GetMembers ( ) . OfType < IMethodSymbol > ( ) )
96103 {
97- if ( s_nonEnumerableReturningLinqMethodNames . Contains ( method . Name ) &&
98- method . Parameters is { Length : 1 } )
104+ if ( method . Parameters . Length != 1 )
105+ continue ;
106+
107+ if ( s_nonEnumerableReturningLinqPredicateMethodNames . Contains ( method . Name ) ||
108+ s_nonEnumerableReturningLinqSelectorMethodNames . Contains ( method . Name ) )
99109 {
100110 linqMethodSymbolsBuilder . AddRange ( method ) ;
101111 }
@@ -104,65 +114,76 @@ static bool TryGetLinqMethodsThatDoNotReturnEnumerables(INamedTypeSymbol enumera
104114 linqMethods = linqMethodSymbolsBuilder . ToImmutable ( ) ;
105115 return linqMethods . Any ( ) ;
106116 }
107- }
108-
109- public void AnalyzeInvocationOperation ( OperationAnalysisContext context , INamedTypeSymbol enumerableType , IMethodSymbol whereMethod , ImmutableArray < IMethodSymbol > linqMethods )
110- {
111- if ( ShouldSkipAnalysis ( context , notification : null ) )
112- return ;
113117
114- if ( context . Operation . Syntax . GetDiagnostics ( ) . Any ( diagnostic => diagnostic . Severity == DiagnosticSeverity . Error ) )
118+ void AnalyzeInvocationOperation ( OperationAnalysisContext context )
115119 {
120+ if ( ShouldSkipAnalysis ( context , notification : null ) )
121+ return ;
122+
116123 // Do not analyze linq methods that contain diagnostics.
117- return ;
118- }
124+ if ( context . Operation . Syntax . GetDiagnostics ( ) . Any ( diagnostic => diagnostic . Severity == DiagnosticSeverity . Error ) )
125+ return ;
119126
120- if ( context . Operation is not IInvocationOperation invocation ||
121- ! IsWhereLinqMethod ( invocation ) )
122- {
123- // we only care about Where methods on linq expressions
124- return ;
125- }
127+ // we only care about Where/Select invocation methods on linq expressions
126128
127- if ( TryGetNextInvocationInChain ( invocation ) is not IInvocationOperation nextInvocation ||
128- ! IsInvocationNonEnumerableReturningLinqMethod ( nextInvocation ) )
129- {
130- // Invocation is not part of a chain of invocations (i.e. Where(x => x is not null).First())
131- return ;
132- }
129+ if ( context . Operation is not IInvocationOperation invocation )
130+ return ;
133131
134- if ( TryGetSymbolOfMemberAccess ( invocation ) is not INamedTypeSymbol targetTypeSymbol ||
135- TryGetMethodName ( nextInvocation ) is not string name )
136- {
137- return ;
138- }
132+ var isWhereMethod = IsWhereLinqMethod ( invocation ) ;
133+ var isSelectMethod = IsSelectLinqMethod ( invocation ) ;
134+ if ( ! isWhereMethod && ! isSelectMethod )
135+ return ;
139136
140- // Do not offer to transpose if there is already a method on the collection named the same as the linq extension
141- // method. This would cause us to call the instance method after the transformation, not the extension method.
142- if ( ! targetTypeSymbol . Equals ( enumerableType , SymbolEqualityComparer . Default ) &&
143- targetTypeSymbol . MemberNames . Contains ( name ) )
144- {
145- // VB conflicts if any member has the same name (like a Count property vs Count extension method).
146- if ( this . ConflictsWithMemberByNameOnly )
137+ if ( TryGetNextInvocationInChain ( invocation ) is not IInvocationOperation nextInvocation ||
138+ ! IsInvocationNonEnumerableReturningLinqMethod ( nextInvocation ) )
139+ {
140+ // Invocation is not part of a chain of invocations (i.e. Where(x => x is not null).First())
147141 return ;
142+ }
148143
149- // C# conflicts only if it is a method as well. So a Count property will not conflict with a Count
150- // extension method.
151- if ( targetTypeSymbol . GetMembers ( name ) . Any ( m => m is IMethodSymbol ) )
144+ if ( TryGetSymbolOfMemberAccess ( invocation ) is not ITypeSymbol targetTypeSymbol ||
145+ TryGetMethodName ( nextInvocation ) is not string name )
146+ {
152147 return ;
153- }
148+ }
154149
155- context . ReportDiagnostic ( Diagnostic . Create ( Descriptor , nextInvocation . Syntax . GetLocation ( ) ) ) ;
150+ if ( isWhereMethod && ! s_nonEnumerableReturningLinqPredicateMethodNames . Contains ( name ) )
151+ return ;
156152
157- return ;
153+ if ( isSelectMethod && ! s_nonEnumerableReturningLinqSelectorMethodNames . Contains ( name ) )
154+ return ;
155+
156+ // Do not offer to transpose if there is already a method on the collection named the same as the linq extension
157+ // method. This would cause us to call the instance method after the transformation, not the extension method.
158+ if ( ! targetTypeSymbol . Equals ( enumerableType , SymbolEqualityComparer . Default ) )
159+ {
160+ var members = targetTypeSymbol . GetMembers ( name ) ;
161+ if ( members . Length > 0 )
162+ {
163+ // VB conflicts if any member has the same name (like a Count property vs Count extension method).
164+ if ( this . ConflictsWithMemberByNameOnly )
165+ return ;
166+
167+ // C# conflicts only if it is a method as well. So a Count property will not conflict with a Count
168+ // extension method.
169+ if ( members . Any ( m => m is IMethodSymbol ) )
170+ return ;
171+ }
172+ }
173+
174+ context . ReportDiagnostic ( Diagnostic . Create ( Descriptor , nextInvocation . Syntax . GetLocation ( ) ) ) ;
175+ }
158176
159177 bool IsWhereLinqMethod ( IInvocationOperation invocation )
160- => whereMethod . Equals ( invocation . TargetMethod . ReducedFrom ?? invocation . TargetMethod . OriginalDefinition , SymbolEqualityComparer . Default ) ;
178+ => whereMethodSymbol . Equals ( invocation . TargetMethod . ReducedFrom ?? invocation . TargetMethod . OriginalDefinition , SymbolEqualityComparer . Default ) ;
179+
180+ bool IsSelectLinqMethod ( IInvocationOperation invocation )
181+ => selectMethodSymbol . Equals ( invocation . TargetMethod . ReducedFrom ?? invocation . TargetMethod . OriginalDefinition , SymbolEqualityComparer . Default ) ;
161182
162183 bool IsInvocationNonEnumerableReturningLinqMethod ( IInvocationOperation invocation )
163184 => linqMethods . Any ( static ( m , invocation ) => m . Equals ( invocation . TargetMethod . ReducedFrom ?? invocation . TargetMethod . OriginalDefinition , SymbolEqualityComparer . Default ) , invocation ) ;
164185
165- INamedTypeSymbol ? TryGetSymbolOfMemberAccess ( IInvocationOperation invocation )
186+ ITypeSymbol ? TryGetSymbolOfMemberAccess ( IInvocationOperation invocation )
166187 {
167188 if ( invocation . Syntax is not TInvocationExpressionSyntax invocationNode ||
168189 SyntaxFacts . GetExpressionOfInvocationExpression ( invocationNode ) is not TMemberAccessExpressionSyntax memberAccess ||
@@ -171,7 +192,7 @@ bool IsInvocationNonEnumerableReturningLinqMethod(IInvocationOperation invocatio
171192 return null ;
172193 }
173194
174- return invocation . SemanticModel ? . GetTypeInfo ( expression ) . Type as INamedTypeSymbol ;
195+ return invocation . SemanticModel ? . GetTypeInfo ( expression ) . Type ;
175196 }
176197
177198 string ? TryGetMethodName ( IInvocationOperation invocation )
0 commit comments