Skip to content

Commit 96cd39d

Browse files
Fix Queryable-to-Enumerable overload mapping logic (#65569)
* Fix Queryable-to-Enumerable overload mapping logic * fix linker warnings * address feedback * use strict order when calculating maximal elements
1 parent b3465af commit 96cd39d

File tree

4 files changed

+103
-9
lines changed

4 files changed

+103
-9
lines changed

src/libraries/System.Linq.Queryable/src/System/Linq/EnumerableRewriter.cs

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,25 +255,86 @@ protected override Expression VisitConstant(ConstantExpression c)
255255
}
256256

257257
private static ILookup<string, MethodInfo>? s_seqMethods;
258-
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2060:MakeGenericMethod",
259-
Justification = "Enumerable methods don't have trim annotations.")]
260258
private static MethodInfo FindEnumerableMethodForQueryable(string name, ReadOnlyCollection<Expression> args, params Type[]? typeArgs)
261259
{
262-
if (s_seqMethods == null)
260+
s_seqMethods ??= GetEnumerableStaticMethods(typeof(Enumerable)).ToLookup(m => m.Name);
261+
262+
MethodInfo[] matchingMethods = s_seqMethods[name]
263+
.Where(m => ArgsMatch(m, args, typeArgs))
264+
.Select(ApplyTypeArgs)
265+
.ToArray();
266+
267+
Debug.Assert(matchingMethods.Length > 0, "All static methods with arguments on Queryable have equivalents on Enumerable.");
268+
269+
if (matchingMethods.Length > 1)
263270
{
264-
s_seqMethods = GetEnumerableStaticMethods(typeof(Enumerable)).ToLookup(m => m.Name);
271+
return DisambiguateMatches(matchingMethods);
265272
}
266-
MethodInfo? mi = s_seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
267-
Debug.Assert(mi != null, "All static methods with arguments on Queryable have equivalents on Enumerable.");
268-
if (typeArgs != null)
269-
return mi.MakeGenericMethod(typeArgs);
270-
return mi;
273+
274+
return matchingMethods[0];
271275

272276
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern",
273277
Justification = "This method is intentionally hiding the Enumerable type from the trimmer so it doesn't preserve all Enumerable's methods. " +
274278
"This is safe because all Queryable methods have a DynamicDependency to the corresponding Enumerable method.")]
275279
static MethodInfo[] GetEnumerableStaticMethods(Type type) =>
276280
type.GetMethods(BindingFlags.Public | BindingFlags.Static);
281+
282+
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2060:MakeGenericMethod",
283+
Justification = "Enumerable methods don't have trim annotations.")]
284+
MethodInfo ApplyTypeArgs(MethodInfo methodInfo) => typeArgs == null ? methodInfo : methodInfo.MakeGenericMethod(typeArgs);
285+
286+
// In certain cases, there might be ambiguities when resolving matching overloads, for example between
287+
// 1. FirstOrDefault<object>(IEnumerable<object> source, Func<object, bool> predicate) and
288+
// 2. FirstOrDefault<object>(IEnumerable<object> source, object defaultvalue).
289+
// In such cases we disambiguate by picking a method with the most derived signature.
290+
static MethodInfo DisambiguateMatches(MethodInfo[] matchingMethods)
291+
{
292+
Debug.Assert(matchingMethods.Length > 1);
293+
ParameterInfo[][] parameters = matchingMethods.Select(m => m.GetParameters()).ToArray();
294+
295+
// `AreAssignableFrom[Strict]` defines a partial order on method signatures; pick a maximal element using that order.
296+
// It is assumed that `matchingMethods` is a small array, so a naive quadratic search is probably better than
297+
// doing some variant of topological sorting.
298+
299+
for (int i = 0; i < matchingMethods.Length; i++)
300+
{
301+
bool isMaximal = true;
302+
for (int j = 0; j < matchingMethods.Length; j++)
303+
{
304+
if (i != j && AreAssignableFromStrict(parameters[i], parameters[j]))
305+
{
306+
// Found a matching method that contains strictly more specific parameter types.
307+
isMaximal = false;
308+
break;
309+
}
310+
}
311+
312+
if (isMaximal)
313+
{
314+
return matchingMethods[i];
315+
}
316+
}
317+
318+
Debug.Fail("Search should have found a maximal element");
319+
throw new Exception();
320+
321+
static bool AreAssignableFromStrict(ParameterInfo[] left, ParameterInfo[] right)
322+
{
323+
Debug.Assert(left.Length == right.Length);
324+
325+
bool areEqual = true;
326+
bool areAssignableFrom = true;
327+
for (int i = 0; i < left.Length; i++)
328+
{
329+
Type leftParam = left[i].ParameterType;
330+
Type rightParam = right[i].ParameterType;
331+
areEqual = areEqual && leftParam == rightParam;
332+
areAssignableFrom = areAssignableFrom && leftParam.IsAssignableFrom(rightParam);
333+
}
334+
335+
return !areEqual && areAssignableFrom;
336+
}
337+
}
277338
}
278339

279340
[RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)]

src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,16 @@ public void FirstOrDefault2()
129129
var val = (new int[] { 0, 1, 2 }).AsQueryable().FirstOrDefault(n => n > 1);
130130
Assert.Equal(2, val);
131131
}
132+
133+
[Fact]
134+
public void FirstOrDefault_OverloadResolution_Regression()
135+
{
136+
// Regression test for https://github.com/dotnet/runtime/issues/65419
137+
object? result = new object[] { 1, "" }.AsQueryable().FirstOrDefault(x => x is string);
138+
Assert.IsType<string>(result);
139+
140+
result = Array.Empty<object>().AsQueryable().FirstOrDefault(1);
141+
Assert.IsType<int>(result);
142+
}
132143
}
133144
}

src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,16 @@ public void LastOrDefault2()
9696
var val = (new int[] { 0, 1, 2 }).AsQueryable().LastOrDefault(n => n > 1);
9797
Assert.Equal(2, val);
9898
}
99+
100+
[Fact]
101+
public void LastOrDefault_OverloadResolution_Regression()
102+
{
103+
// Regression test for https://github.com/dotnet/runtime/issues/65419
104+
object? result = new object[] { 1, "" }.AsQueryable().LastOrDefault(x => x is int);
105+
Assert.IsType<int>(result);
106+
107+
result = Array.Empty<object>().AsQueryable().LastOrDefault(1);
108+
Assert.IsType<int>(result);
109+
}
99110
}
100111
}

src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,16 @@ public void SingleOrDefault2()
7979
var val = (new int[] { 2 }).AsQueryable().SingleOrDefault(n => n > 1);
8080
Assert.Equal(2, val);
8181
}
82+
83+
[Fact]
84+
public void SingleOrDefault_OverloadResolution_Regression()
85+
{
86+
// Regression test for https://github.com/dotnet/runtime/issues/65419
87+
object? result = new object[] { 1, "" }.AsQueryable().SingleOrDefault(x => x is string);
88+
Assert.IsType<string>(result);
89+
90+
result = Array.Empty<object>().AsQueryable().SingleOrDefault(1);
91+
Assert.IsType<int>(result);
92+
}
8293
}
8394
}

0 commit comments

Comments
 (0)