Skip to content

Commit ad5700b

Browse files
sanych-sunrstam
authored andcommitted
CSHARP-4550: NewExpression and MemberInitExpression behaviour differs (#1089)
1 parent 7746b0f commit ad5700b

File tree

2 files changed

+194
-27
lines changed

2 files changed

+194
-27
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,28 @@ internal static class MemberInitExpressionToAggregationExpressionTranslator
2828
{
2929
public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression)
3030
{
31-
var computedFields = new List<AstComputedField>();
32-
var classMap = CreateClassMap(expression.Type);
33-
3431
var newExpression = expression.NewExpression;
3532
var constructorInfo = newExpression.Constructor;
36-
var constructorParameters = constructorInfo.GetParameters();
3733
var constructorArguments = newExpression.Arguments;
38-
var memberNames = new string[constructorParameters.Length];
39-
for (var i = 0; i < constructorParameters.Length; i++)
40-
{
41-
var constructorParameter = constructorParameters[i];
42-
var memberMap = FindMatchingMemberMap(expression, classMap, constructorParameter);
4334

44-
var argumentExpression = constructorArguments[i];
45-
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
46-
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, argumentTranslation.Ast));
35+
var classMap = CreateClassMap(expression.Type, constructorInfo, out var creatorMap);
36+
var creatorMapParameters = creatorMap.Arguments?.ToArray();
37+
if (constructorInfo.GetParameters().Length > 0 && creatorMapParameters == null )
38+
{
39+
throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters.");
40+
}
4741

48-
memberMap.SetSerializer(argumentTranslation.Serializer);
49-
memberNames[i] = memberMap.MemberName;
42+
var computedFields = new List<AstComputedField>();
43+
for (var i = 0; i < creatorMapParameters.Length; i++)
44+
{
45+
var creatorMapParameter = creatorMapParameters[i];
46+
var constructorArgumentExpression = constructorArguments[i];
47+
var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression);
48+
var constructorArgumentType = constructorArgumentExpression.Type;
49+
var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType);
50+
var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter);
51+
memberMap.SetSerializer(constructorArgumentSerializer);
52+
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast));
5053
}
5154

5255
foreach (var binding in expression.Bindings)
@@ -63,48 +66,69 @@ public static AggregationExpression Translate(TranslationContext context, Member
6366
}
6467

6568
var ast = AstExpression.ComputedDocument(computedFields);
66-
67-
classMap.MapConstructor(constructorInfo, memberNames);
6869
classMap.Freeze();
6970
var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(expression.Type);
7071
var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap);
7172

7273
return new AggregationExpression(expression, ast, serializer);
7374
}
7475

75-
private static BsonClassMap CreateClassMap(Type classType)
76+
private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap)
7677
{
7778
BsonClassMap baseClassMap = null;
7879
if (classType.BaseType != null)
7980
{
80-
baseClassMap = CreateClassMap(classType.BaseType);
81+
baseClassMap = CreateClassMap(classType.BaseType, null, out _);
8182
}
8283

8384
var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType);
84-
var constructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) });
85-
var classMap = (BsonClassMap)constructorInfo.Invoke(new object[] { baseClassMap });
85+
var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) });
86+
var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap });
87+
if (constructorInfo != null)
88+
{
89+
creatorMap = classMap.MapConstructor(constructorInfo);
90+
}
91+
else
92+
{
93+
creatorMap = null;
94+
}
95+
8696
classMap.AutoMap();
8797
classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here
8898

8999
return classMap;
90100
}
91101

92-
private static BsonMemberMap FindMatchingMemberMap(Expression expression, BsonClassMap classMap, ParameterInfo parameterInfo)
102+
private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter)
93103
{
94-
foreach (var memberMap in classMap.DeclaredMemberMaps)
104+
var declaringClassMap = classMap;
105+
while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType)
95106
{
96-
if (memberMap.MemberType == parameterInfo.ParameterType && memberMap.MemberName.Equals(parameterInfo.Name, StringComparison.OrdinalIgnoreCase))
107+
declaringClassMap = declaringClassMap.BaseClassMap;
108+
109+
if (declaringClassMap == null)
97110
{
98-
return memberMap;
111+
throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}");
99112
}
100113
}
101114

102-
if (classMap.BaseClassMap != null)
115+
foreach (var memberMap in declaringClassMap.DeclaredMemberMaps)
103116
{
104-
return FindMatchingMemberMap(expression, classMap.BaseClassMap, parameterInfo);
117+
if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter))
118+
{
119+
return memberMap;
120+
}
105121
}
106122

107-
throw new ExpressionNotSupportedException(expression, because: $"can't find matching property for constructor parameter : {parameterInfo.Name}");
123+
return declaringClassMap.MapMember(creatorMapParameter);
124+
125+
static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter)
126+
{
127+
var memberInfo = memberMap.MemberInfo;
128+
return
129+
memberInfo.MemberType == creatorMapParameter.MemberType &&
130+
memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase);
131+
}
108132
}
109133

110134
private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using FluentAssertions;
18+
using MongoDB.Driver.Linq;
19+
using Xunit;
20+
21+
namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Translators.ExpressionToAggregationExpressionTranslators
22+
{
23+
public class MemberInitExpressionToAggregationExpressionTranslatorTests : Linq3IntegrationTest
24+
{
25+
private readonly IMongoCollection<MyData> _collection;
26+
27+
public MemberInitExpressionToAggregationExpressionTranslatorTests()
28+
{
29+
_collection = CreateCollection(LinqProvider.V3);
30+
}
31+
32+
[Fact]
33+
public void Should_project_via_parameterless_constructor()
34+
{
35+
var queryable = _collection.AsQueryable()
36+
.Select(x => new SpawnDataParameterless
37+
{
38+
Identifier = x.Id,
39+
SpawnDate = x.Date,
40+
SpawnText = x.Text
41+
});
42+
43+
var stages = Translate(_collection, queryable);
44+
AssertStages(stages, "{ $project : { Identifier : '$_id', SpawnDate : '$Date', SpawnText : '$Text', _id : 0 } }");
45+
46+
var results = queryable.Single();
47+
48+
results.SpawnDate.Should().Be(new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc));
49+
results.SpawnText.Should().Be("data text");
50+
results.Identifier.Should().Be(1);
51+
}
52+
53+
[Fact]
54+
public void Should_project_via_constructor()
55+
{
56+
var queryable = _collection.AsQueryable()
57+
.Select(x => new SpawnData(x.Id, x.Date)
58+
{
59+
SpawnText = x.Text
60+
});
61+
62+
var stages = Translate(_collection, queryable);
63+
AssertStages(stages, "{ $project : { Identifier : '$_id', SpawnDate : '$Date', SpawnText : '$Text', _id : 0 } }");
64+
65+
var results = queryable.Single();
66+
67+
results.SpawnDate.Should().Be(new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc));
68+
results.SpawnText.Should().Be("data text");
69+
results.Identifier.Should().Be(1);
70+
}
71+
72+
[Fact]
73+
public void Should_project_via_constructor_with_inheritance()
74+
{
75+
var queryable = _collection.AsQueryable()
76+
.Select(x => new InheritedSpawnData(x.Id, x.Date)
77+
{
78+
SpawnText = x.Text
79+
});
80+
81+
var stages = Translate(_collection, queryable);
82+
AssertStages(stages, "{ $project : { Identifier : '$_id', SpawnDate : '$Date', SpawnText : '$Text', _id : 0 } }");
83+
84+
var results = queryable.Single();
85+
86+
results.SpawnDate.Should().Be(new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc));
87+
results.SpawnText.Should().Be("data text");
88+
results.Identifier.Should().Be(1);
89+
}
90+
91+
private IMongoCollection<MyData> CreateCollection(LinqProvider linqProvider)
92+
{
93+
var collection = GetCollection<MyData>("data", linqProvider);
94+
95+
CreateCollection(
96+
collection,
97+
new MyData { Id = 1, Date = new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc), Text = "data text" });
98+
99+
return collection;
100+
}
101+
102+
public class MyData
103+
{
104+
public int Id { get; set; }
105+
public DateTime Date;
106+
public string Text;
107+
}
108+
109+
public class SpawnDataParameterless
110+
{
111+
public int Identifier;
112+
public DateTime SpawnDate;
113+
public string SpawnText;
114+
}
115+
116+
public class SpawnData
117+
{
118+
public readonly int Identifier;
119+
public DateTime SpawnDate;
120+
private string spawnText;
121+
122+
public SpawnData(int identifier, DateTime spawnDate)
123+
{
124+
Identifier = identifier;
125+
SpawnDate = spawnDate;
126+
}
127+
128+
public string SpawnText
129+
{
130+
get => spawnText;
131+
set => spawnText = value;
132+
}
133+
}
134+
135+
public class InheritedSpawnData : SpawnData
136+
{
137+
public InheritedSpawnData(int identifier, DateTime spawnDate)
138+
: base(identifier, spawnDate)
139+
{
140+
}
141+
}
142+
}
143+
}

0 commit comments

Comments
 (0)