Skip to content

Commit

Permalink
Collection expressions: use conversion to check for ICollection<T> fo…
Browse files Browse the repository at this point in the history
…r spread optimization (#74949)
  • Loading branch information
cston authored Aug 30, 2024
1 parent 2805273 commit 9907b79
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,8 @@ private BoundExpression CreateAndPopulateList(BoundCollectionExpression node, Ty
if (addRangeMethod is null)
return false;
Conversion conversion;
if (spreadElement.EnumeratorInfoOpt is { } enumeratorInfo)
{
var iCollectionOfTType = _compilation.GetSpecialType(SpecialType.System_Collections_Generic_ICollection_T);
Expand All @@ -1075,19 +1077,22 @@ private BoundExpression CreateAndPopulateList(BoundCollectionExpression node, Ty
// If collection has a struct enumerator but doesn't implement ICollection<T>
// then manual `foreach` is always more efficient then using `AddRange` method
if (enumeratorInfo.GetEnumeratorInfo.Method.ReturnType.IsValueType &&
!enumeratorInfo.CollectionType.ImplementsInterface(iCollectionOfElementType, ref discardedUseSiteInfo))
if (enumeratorInfo.GetEnumeratorInfo.Method.ReturnType.IsValueType)
{
return false;
conversion = _compilation.Conversions.ClassifyBuiltInConversion(enumeratorInfo.CollectionType, iCollectionOfElementType, isChecked: false, ref discardedUseSiteInfo);
if (!(conversion.Kind is ConversionKind.Identity or ConversionKind.ImplicitReference))
{
return false;
}
}
}
var type = rewrittenSpreadOperand.Type!;
var useSiteInfo = GetNewCompoundUseSiteInfo();
var conversion = _compilation.Conversions.ClassifyConversionFromType(type, addRangeMethod.Parameters[0].Type, isChecked: false, ref useSiteInfo);
conversion = _compilation.Conversions.ClassifyConversionFromType(type, addRangeMethod.Parameters[0].Type, isChecked: false, ref useSiteInfo);
_diagnostics.Add(rewrittenSpreadOperand.Syntax, useSiteInfo);
if (conversion.IsIdentity || (conversion.IsImplicit && conversion.IsReference))
if (conversion.Kind is ConversionKind.Identity or ConversionKind.ImplicitReference)
{
conversion.MarkUnderlyingConversionsCheckedRecursive();
sideEffects.Add(_factory.Call(listTemp, addRangeMethod, rewrittenSpreadOperand));
Expand Down
179 changes: 179 additions & 0 deletions src/Compilers/CSharp/Test/Emit2/Semantics/CollectionExpressionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35267,6 +35267,185 @@ .locals init (int V_0,
verifier.VerifyIL("C.M", expectedIL);
}

[WorkItem("https://github.com/dotnet/roslyn/issues/74894")]
[Fact]
public void List_Spread_StructEnumerator()
{
var source = """
using System.Collections.Generic;
interface IMyEnumerable<T> : IEnumerable<T>
{
new MyEnumerator<T> GetEnumerator();
}
interface IMyCollection<T> : ICollection<T>
{
new MyEnumerator<T> GetEnumerator();
}
struct MyEnumerator<T>
{
private readonly List<T> _list;
private int _index;
public MyEnumerator(List<T> list) { _list = list; _index = -1; }
public T Current => _list[_index];
public bool MoveNext()
{
if (_index < _list.Count) _index++;
return _index < _list.Count;
}
}
class MyEnumerable<T> : List<T>, IMyEnumerable<T>
{
MyEnumerator<T> IMyEnumerable<T>.GetEnumerator() => new(this);
}
class MyCollection<T> : List<T>, IMyCollection<T>
{
MyEnumerator<T> IMyCollection<T>.GetEnumerator() => new(this);
}
class Program
{
static void Main()
{
MyEnumerable<object> x = [1, 2];
MyCollection<object> y = [1, 2];
object z = 3;
M1(x, z).Report();
M2(y, z).Report();
M3(x, z).Report();
M4(y, z).Report();
}
static List<object> M1(IMyEnumerable<object> x, object y)
{
return [..x, y];
}
static List<object> M2(IMyCollection<object> x, object y)
{
return [..x, y];
}
#nullable enable
static List<U?> M3<T, U>(T x, U y)
where T : class, IMyEnumerable<U>
where U : class
{
return [..x, y];
}
static List<U?> M4<T, U>(T x, U y)
where T : class, IMyCollection<U>
where U : class
{
return [..x, y];
}
}
""";
var verifier = CompileAndVerify(
[source, s_collectionExtensions],
targetFramework: TargetFramework.Standard,
expectedOutput: "[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], ",
verify: Verification.Skipped);
verifier.VerifyDiagnostics();
verifier.VerifyIL("Program.M1", """
{
// Code size 48 (0x30)
.maxstack 2
.locals init (System.Collections.Generic.List<object> V_0,
MyEnumerator<object> V_1,
object V_2)
IL_0000: newobj "System.Collections.Generic.List<object>..ctor()"
IL_0005: stloc.0
IL_0006: ldarg.0
IL_0007: callvirt "MyEnumerator<object> IMyEnumerable<object>.GetEnumerator()"
IL_000c: stloc.1
IL_000d: br.s IL_001e
IL_000f: ldloca.s V_1
IL_0011: call "object MyEnumerator<object>.Current.get"
IL_0016: stloc.2
IL_0017: ldloc.0
IL_0018: ldloc.2
IL_0019: callvirt "void System.Collections.Generic.List<object>.Add(object)"
IL_001e: ldloca.s V_1
IL_0020: call "bool MyEnumerator<object>.MoveNext()"
IL_0025: brtrue.s IL_000f
IL_0027: ldloc.0
IL_0028: ldarg.1
IL_0029: callvirt "void System.Collections.Generic.List<object>.Add(object)"
IL_002e: ldloc.0
IL_002f: ret
}
""");
verifier.VerifyIL("Program.M2", """
{
// Code size 30 (0x1e)
.maxstack 3
.locals init (IMyCollection<object> V_0)
IL_0000: ldarg.0
IL_0001: stloc.0
IL_0002: ldc.i4.1
IL_0003: ldloc.0
IL_0004: callvirt "int System.Collections.Generic.ICollection<object>.Count.get"
IL_0009: add
IL_000a: newobj "System.Collections.Generic.List<object>..ctor(int)"
IL_000f: dup
IL_0010: ldloc.0
IL_0011: callvirt "void System.Collections.Generic.List<object>.AddRange(System.Collections.Generic.IEnumerable<object>)"
IL_0016: dup
IL_0017: ldarg.1
IL_0018: callvirt "void System.Collections.Generic.List<object>.Add(object)"
IL_001d: ret
}
""");
verifier.VerifyIL("Program.M3<T, U>(T, U)", """
{
// Code size 53 (0x35)
.maxstack 2
.locals init (System.Collections.Generic.List<U> V_0,
MyEnumerator<U> V_1,
U V_2)
IL_0000: newobj "System.Collections.Generic.List<U>..ctor()"
IL_0005: stloc.0
IL_0006: ldarg.0
IL_0007: box "T"
IL_000c: callvirt "MyEnumerator<U> IMyEnumerable<U>.GetEnumerator()"
IL_0011: stloc.1
IL_0012: br.s IL_0023
IL_0014: ldloca.s V_1
IL_0016: call "U MyEnumerator<U>.Current.get"
IL_001b: stloc.2
IL_001c: ldloc.0
IL_001d: ldloc.2
IL_001e: callvirt "void System.Collections.Generic.List<U>.Add(U)"
IL_0023: ldloca.s V_1
IL_0025: call "bool MyEnumerator<U>.MoveNext()"
IL_002a: brtrue.s IL_0014
IL_002c: ldloc.0
IL_002d: ldarg.1
IL_002e: callvirt "void System.Collections.Generic.List<U>.Add(U)"
IL_0033: ldloc.0
IL_0034: ret
}
""");
verifier.VerifyIL("Program.M4<T, U>(T, U)", """
{
// Code size 35 (0x23)
.maxstack 3
.locals init (T V_0)
IL_0000: ldarg.0
IL_0001: stloc.0
IL_0002: ldc.i4.1
IL_0003: ldloc.0
IL_0004: box "T"
IL_0009: callvirt "int System.Collections.Generic.ICollection<U>.Count.get"
IL_000e: add
IL_000f: newobj "System.Collections.Generic.List<U>..ctor(int)"
IL_0014: dup
IL_0015: ldloc.0
IL_0016: callvirt "void System.Collections.Generic.List<U>.AddRange(System.Collections.Generic.IEnumerable<U>)"
IL_001b: dup
IL_001c: ldarg.1
IL_001d: callvirt "void System.Collections.Generic.List<U>.Add(U)"
IL_0022: ret
}
""");
}

[Fact]
public void List_AddRange_IEnumerable()
{
Expand Down

0 comments on commit 9907b79

Please sign in to comment.