diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs index 3e4d72745fe..be03832fa6d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs @@ -16,7 +16,6 @@ using MongoDB.Bson; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters; -using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers @@ -37,6 +36,48 @@ public static TNode SimplifyAndConvert(TNode node) } #endregion + public override AstNode VisitCondExpression(AstCondExpression node) + { + // { $cond : [{ $eq : [expr1, null] }, null, expr2] } + if (node.If is AstBinaryExpression binaryIfpression && + binaryIfpression.Operator == AstBinaryOperator.Eq && + binaryIfpression.Arg1 is AstExpression expr1 && + binaryIfpression.Arg2 is AstConstantExpression constantComparandExpression && + constantComparandExpression.Value == BsonNull.Value && + node.Then is AstConstantExpression constantThenExpression && + constantThenExpression.Value == BsonNull.Value && + node.Else is AstExpression expr2) + { + // { $cond : [{ $eq : [expr, null] }, null, expr] } => expr + if (expr1 == expr2) + { + return Visit(expr2); + } + + // { $cond : [{ $eq : [expr, null] }, null, { $toT : expr }] } => { $toT : expr } for operators that map null to null + if (expr2 is AstUnaryExpression unaryElseExpression && + OperatorMapsNullToNull(unaryElseExpression.Operator) && + unaryElseExpression.Arg == expr1) + { + return Visit(expr2); + } + } + + return base.VisitCondExpression(node); + + static bool OperatorMapsNullToNull(AstUnaryOperator @operator) + { + return @operator switch + { + AstUnaryOperator.ToDecimal => true, + AstUnaryOperator.ToDouble => true, + AstUnaryOperator.ToInt => true, + AstUnaryOperator.ToLong => true, + _ => false + }; + } + } + public override AstNode VisitFieldOperationFilter(AstFieldOperationFilter node) { node = (AstFieldOperationFilter)base.VisitFieldOperationFilter(node); @@ -281,6 +322,22 @@ bool TrySimplifyAsLet(AstGetFieldExpression node, out AstExpression simplified) } } + public override AstNode VisitLetExpression(AstLetExpression node) + { + node = (AstLetExpression)base.VisitLetExpression(node); + + // { $let : { vars : { var : expr }, in : "$$var" } } => expr + if (node.Vars.Count == 1 && + node.Vars[0].Var.Name is string varName && + node.In is AstVarExpression varExpression && + varExpression.Name == varName) + { + return node.Vars[0].Value; + } + + return node; + } + public override AstNode VisitMapExpression(AstMapExpression node) { // { $map : { input : , as : "v", in : "$$v.x" } } => { $getField : { field : "x", input : } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs index 8596f64e2b0..e511e24bc6a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs @@ -30,110 +30,121 @@ public static AggregationExpression Translate(TranslationContext context, UnaryE { if (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.TypeAs) { - var expressionType = expression.Type; - if (expressionType == typeof(BsonValue)) + var sourceExpression = expression.Operand; + var sourceType = sourceExpression.Type; + var targetType = expression.Type; + + // handle double conversions like `(BsonValue)(object)x` + if (targetType == typeof(BsonValue) && + sourceExpression is UnaryExpression unarySourceExpression && + unarySourceExpression.NodeType == ExpressionType.Convert && + unarySourceExpression.Type == typeof(object)) { - return TranslateConvertToBsonValue(context, expression, expression.Operand); + sourceExpression = unarySourceExpression.Operand; } - var operandExpression = expression.Operand; - var operandTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, operandExpression); + var sourceTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, sourceExpression); + return Translate(expression, sourceType, targetType, sourceTranslation); + } - if (expressionType == operandExpression.Type) - { - return operandTranslation; - } + throw new ExpressionNotSupportedException(expression); + } - if (IsConvertEnumToUnderlyingType(expression)) - { - return TranslateConvertEnumToUnderlyingType(expression, operandTranslation); - } + private static AggregationExpression Translate(UnaryExpression expression, Type sourceType, Type targetType, AggregationExpression sourceTranslation) + { + if (targetType == sourceType) + { + return sourceTranslation; + } - if (IsConvertUnderlyingTypeToEnum(expression)) - { - return TranslateConvertUnderlyingTypeToEnum(expression, operandTranslation); - } + // from Nullable must be handled before to Nullable + if (IsConvertFromNullableType(sourceType)) + { + return TranslateConvertFromNullableType(expression, sourceType, targetType, sourceTranslation); + } - if (IsConvertEnumToEnum(expression)) - { - return TranslateConvertEnumToEnum(expression, operandTranslation); - } + if (IsConvertToNullableType(targetType)) + { + return TranslateConvertToNullableType(expression, sourceType, targetType, sourceTranslation); + } - if (IsConvertToBaseType(sourceType: operandExpression.Type, targetType: expressionType)) - { - return TranslateConvertToBaseType(expression, operandTranslation); - } + // from here on we know there are no longer any Nullable types involved - if (IsConvertToDerivedType(sourceType: operandExpression.Type, targetType: expressionType)) - { - return TranslateConvertToDerivedType(expression, operandTranslation); - } + if (targetType == typeof(BsonValue)) + { + return TranslateConvertToBsonValue(expression, sourceTranslation); + } - if (expressionType.IsConstructedGenericType && expressionType.GetGenericTypeDefinition() == typeof(Nullable<>)) - { - var valueType = expressionType.GetGenericArguments()[0]; - if (operandExpression.Type == valueType) - { - // use the same AST but with a new nullable serializer - var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(valueType); - var valueSerializerType = typeof(IBsonSerializer<>).MakeGenericType(valueType); - var constructorInfo = nullableSerializerType.GetConstructor(new[] { valueSerializerType }); - var nullableSerializer = (IBsonSerializer)constructorInfo.Invoke(new[] { operandTranslation.Serializer }); - return new AggregationExpression(expression, operandTranslation.Ast, nullableSerializer); - } - } + if (IsConvertEnumToUnderlyingType(sourceType, targetType)) + { + return TranslateConvertEnumToUnderlyingType(expression, sourceType, targetType, sourceTranslation); + } - var ast = operandTranslation.Ast; - IBsonSerializer serializer; - if (expressionType.IsInterface) - { - // when an expression is cast to an interface it's a no-op as far as we're concerned - // and we can just use the serializer for the concrete type and members not defined in the interface will just be ignored - serializer = operandTranslation.Serializer; - } - else + if (IsConvertUnderlyingTypeToEnum(sourceType, targetType)) + { + return TranslateConvertUnderlyingTypeToEnum(expression, sourceType, targetType, sourceTranslation); + } + + if (IsConvertEnumToEnum(sourceType, targetType)) + { + return TranslateConvertEnumToEnum(expression, sourceType, targetType, sourceTranslation); + } + + if (IsConvertToBaseType(sourceType, targetType)) + { + return TranslateConvertToBaseType(expression, sourceType, targetType, sourceTranslation); + } + + if (IsConvertToDerivedType(sourceType, targetType)) + { + return TranslateConvertToDerivedType(expression, targetType, sourceTranslation); + } + + var ast = sourceTranslation.Ast; + IBsonSerializer serializer; + if (targetType.IsInterface) + { + // when an expression is cast to an interface it's a no-op as far as we're concerned + // and we can just use the serializer for the concrete type and members not defined in the interface will just be ignored + serializer = sourceTranslation.Serializer; + } + else + { + AstExpression to; + switch (targetType.FullName) { - AstExpression to; - switch (expressionType.FullName) - { - case "MongoDB.Bson.ObjectId": to = "objectId"; serializer = ObjectIdSerializer.Instance; break; - case "System.Boolean": to = "bool"; serializer = BooleanSerializer.Instance; break; - case "System.DateTime": to = "date"; serializer = DateTimeSerializer.Instance; break; - case "System.Decimal": to = "decimal"; serializer = DecimalSerializer.Decimal128Instance; break; // not the default representation - case "System.Double": to = "double"; serializer = DoubleSerializer.Instance; break; - case "System.Int32": to = "int"; serializer = Int32Serializer.Instance; break; - case "System.Int64": to = "long"; serializer = Int64Serializer.Instance; break; - case "System.String": to = "string"; serializer = StringSerializer.Instance; break; - default: throw new ExpressionNotSupportedException(expression, because: $"conversion to {expressionType} is not supported"); - } - - ast = AstExpression.Convert(ast, to); + case "MongoDB.Bson.ObjectId": to = "objectId"; serializer = ObjectIdSerializer.Instance; break; + case "System.Boolean": to = "bool"; serializer = BooleanSerializer.Instance; break; + case "System.DateTime": to = "date"; serializer = DateTimeSerializer.Instance; break; + case "System.Decimal": to = "decimal"; serializer = DecimalSerializer.Decimal128Instance; break; // not the default representation + case "System.Double": to = "double"; serializer = DoubleSerializer.Instance; break; + case "System.Int32": to = "int"; serializer = Int32Serializer.Instance; break; + case "System.Int64": to = "long"; serializer = Int64Serializer.Instance; break; + case "System.String": to = "string"; serializer = StringSerializer.Instance; break; + default: throw new ExpressionNotSupportedException(expression, because: $"conversion to {targetType} is not supported"); } - return new AggregationExpression(expression, ast, serializer); + ast = AstExpression.Convert(ast, to); } - throw new ExpressionNotSupportedException(expression); + return new AggregationExpression(expression, ast, serializer); } - private static bool IsConvertEnumToEnum(UnaryExpression expression) + private static bool IsConvertEnumToEnum(Type sourceType, Type targetType) { - var sourceType = expression.Operand.Type; - var targetType = expression.Type; + return sourceType.IsEnum && targetType.IsEnum; + } + private static bool IsConvertEnumToUnderlyingType(Type sourceType, Type targetType) + { return - sourceType.IsEnumOrNullableEnum(out _, out _) && - targetType.IsEnumOrNullableEnum(out _, out _); + sourceType.IsEnum(out var underlyingType) && + targetType == underlyingType; } - private static bool IsConvertEnumToUnderlyingType(UnaryExpression expression) + private static bool IsConvertFromNullableType(Type sourceType) { - var sourceType = expression.Operand.Type; - var targetType = expression.Type; - - return - sourceType.IsEnumOrNullableEnum(out _, out var underlyingType) && - targetType.IsSameAsOrNullableOf(underlyingType); + return sourceType.IsNullable(); } private static bool IsConvertToBaseType(Type sourceType, Type targetType) @@ -146,127 +157,112 @@ private static bool IsConvertToDerivedType(Type sourceType, Type targetType) return targetType.IsSubclassOf(sourceType); } - private static bool IsConvertUnderlyingTypeToEnum(UnaryExpression expression) + private static bool IsConvertToNullableType(Type targetType) { - var sourceType = expression.Operand.Type; - var targetType = expression.Type; + return targetType.IsNullable(); + } + private static bool IsConvertUnderlyingTypeToEnum(Type sourceType, Type targetType) + { return - targetType.IsEnumOrNullableEnum(out _, out var underlyingType) && - sourceType.IsSameAsOrNullableOf(underlyingType); + targetType.IsEnum(out var underlyingType) && + sourceType == underlyingType; } - private static AggregationExpression TranslateConvertToBaseType(UnaryExpression expression, AggregationExpression operandTranslation) + private static AggregationExpression TranslateConvertToBaseType(UnaryExpression expression, Type sourceType, Type targetType, AggregationExpression sourceTranslation) { - var baseType = expression.Type; - var derivedType = expression.Operand.Type; - var derivedTypeSerializer = operandTranslation.Serializer; - var downcastingSerializer = DowncastingSerializer.Create(baseType, derivedType, derivedTypeSerializer); + var derivedTypeSerializer = sourceTranslation.Serializer; + var downcastingSerializer = DowncastingSerializer.Create(targetType, sourceType, derivedTypeSerializer); - return new AggregationExpression(expression, operandTranslation.Ast, downcastingSerializer); + return new AggregationExpression(expression, sourceTranslation.Ast, downcastingSerializer); } - private static AggregationExpression TranslateConvertToDerivedType(UnaryExpression expression, AggregationExpression operandTranslation) + private static AggregationExpression TranslateConvertToDerivedType(UnaryExpression expression, Type targetType, AggregationExpression sourceTranslation) { - var serializer = BsonSerializer.LookupSerializer(expression.Type); + var serializer = BsonSerializer.LookupSerializer(targetType); - return new AggregationExpression(expression, operandTranslation.Ast, serializer); + return new AggregationExpression(expression, sourceTranslation.Ast, serializer); } - private static AggregationExpression TranslateConvertToBsonValue(TranslationContext context, UnaryExpression expression, Expression operand) + private static AggregationExpression TranslateConvertToBsonValue(UnaryExpression expression, AggregationExpression sourceTranslation) { - // handle double conversions like `(BsonValue)(object)x.Anything` - if (operand is UnaryExpression unaryExpression && - unaryExpression.NodeType == ExpressionType.Convert && - unaryExpression.Type == typeof(object)) - { - operand = unaryExpression.Operand; - } - - var operandTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, operand); - - return new AggregationExpression(expression, operandTranslation.Ast, BsonValueSerializer.Instance); + return new AggregationExpression(expression, sourceTranslation.Ast, BsonValueSerializer.Instance); } - private static AggregationExpression TranslateConvertEnumToEnum(UnaryExpression expression, AggregationExpression operandTranslation) + private static AggregationExpression TranslateConvertEnumToEnum(UnaryExpression expression, Type sourceType, Type targetType, AggregationExpression sourceTranslation) { - var sourceType = expression.Operand.Type; - var targetType = expression.Type; - - if (!sourceType.IsEnumOrNullableEnum(out var sourceEnumType, out _)) + if (!sourceType.IsEnum) { - throw new ExpressionNotSupportedException(expression, because: "source type is not an enum or nullable enum"); + throw new ExpressionNotSupportedException(expression, because: "source type is not an enum"); } - if (!targetType.IsEnumOrNullableEnum(out var targetEnumType, out _)) + if (!targetType.IsEnum) { - throw new ExpressionNotSupportedException(expression, because: "target type is not an enum or nullable enum"); + throw new ExpressionNotSupportedException(expression, because: "target type is not an enum"); } - var sourceSerializer = operandTranslation.Serializer; - IBsonSerializer targetEnumSerializer; - if (targetEnumType == sourceEnumType) + var sourceSerializer = sourceTranslation.Serializer; + if (sourceSerializer is IHasRepresentationSerializer sourceHasRepresentationSerializer && + !SerializationHelper.IsNumericRepresentation(sourceHasRepresentationSerializer.Representation)) { - targetEnumSerializer = sourceSerializer is INullableSerializer sourceNullableSerializer ? - sourceNullableSerializer.ValueSerializer : - sourceSerializer; + throw new ExpressionNotSupportedException(expression, because: "source enum is not represented as a number"); } - else + + var targetSerializer = EnumSerializer.Create(targetType); + return new AggregationExpression(expression, sourceTranslation.Ast, targetSerializer); + } + + private static AggregationExpression TranslateConvertEnumToUnderlyingType(UnaryExpression expression, Type sourceType, Type targetType, AggregationExpression sourceTranslation) + { + var enumSerializer = sourceTranslation.Serializer; + var targetSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer); + return new AggregationExpression(expression, sourceTranslation.Ast, targetSerializer); + } + + private static AggregationExpression TranslateConvertFromNullableType(UnaryExpression expression, Type sourceType, Type targetType, AggregationExpression sourceTranslation) + { + if (sourceType.IsNullable(out var sourceValueType)) { - if (sourceSerializer is IHasRepresentationSerializer sourceHasRepresentationSerializer && - !SerializationHelper.IsNumericRepresentation(sourceHasRepresentationSerializer.Representation)) - { - throw new ExpressionNotSupportedException(expression, because: "source enum is not represented as a number"); - } + var (sourceVarBinding, sourceAst) = AstExpression.UseVarIfNotSimple("source", sourceTranslation.Ast); + var sourceNullableSerializer = (INullableSerializer)sourceTranslation.Serializer; + var sourceValueSerializer = sourceNullableSerializer.ValueSerializer; + var sourceValueAggregationExpression = new AggregationExpression(expression.Operand, sourceAst, sourceValueSerializer); + var convertTranslation = Translate(expression, sourceValueType, targetType, sourceValueAggregationExpression); - targetEnumSerializer = EnumSerializer.Create(targetEnumType); - } + // note: we would have liked to throw a query execution error here if the value is null and the target type is not nullable but there is no way to do that in MQL + // so we just return null instead and the user must check for null themselves if they want to define what happens when the value is null + // but see SERVER-78092 and the proposed $error operator - var targetSerializer = targetType.IsNullable() ? - NullableSerializer.Create(targetEnumSerializer) : - targetEnumSerializer; + var ast = AstExpression.Let( + sourceVarBinding, + AstExpression.Cond(AstExpression.Eq(sourceAst, BsonNull.Value), BsonNull.Value, convertTranslation.Ast)); - return new AggregationExpression(expression, operandTranslation.Ast, targetSerializer); + return new AggregationExpression(expression, ast, convertTranslation.Serializer); + } + + throw new ExpressionNotSupportedException(expression, because: "sourceType is not nullable"); } - private static AggregationExpression TranslateConvertEnumToUnderlyingType(UnaryExpression expression, AggregationExpression operandTranslation) + private static AggregationExpression TranslateConvertToNullableType(UnaryExpression expression, Type sourceType, Type targetType, AggregationExpression sourceTranslation) { - var sourceType = expression.Operand.Type; - var targetType = expression.Type; - - IBsonSerializer enumSerializer; if (sourceType.IsNullable()) { - var nullableSerializer = (INullableSerializer)operandTranslation.Serializer; - enumSerializer = nullableSerializer.ValueSerializer; - } - else - { - enumSerializer = operandTranslation.Serializer; + // ConvertFromNullableType should have been called first + throw new ExpressionNotSupportedException(expression, because: "sourceType is nullable"); } - IBsonSerializer targetSerializer; - var enumUnderlyingTypeSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer); - if (targetType.IsNullable()) - { - targetSerializer = NullableSerializer.Create(enumUnderlyingTypeSerializer); - } - else + if (targetType.IsNullable(out var targetValueType)) { - targetSerializer = enumUnderlyingTypeSerializer; + var convertTranslation = Translate(expression, sourceType, targetValueType, sourceTranslation); + var nullableSerializer = NullableSerializer.Create(convertTranslation.Serializer); + return new AggregationExpression(expression, convertTranslation.Ast, nullableSerializer); } - return new AggregationExpression(expression, operandTranslation.Ast, targetSerializer); + throw new ExpressionNotSupportedException(expression, because: "targetType is not nullable"); } - private static AggregationExpression TranslateConvertUnderlyingTypeToEnum(UnaryExpression expression, AggregationExpression operandTranslation) + private static AggregationExpression TranslateConvertUnderlyingTypeToEnum(UnaryExpression expression, Type sourceType, Type targetType, AggregationExpression sourceTranslation) { - var targetType = expression.Type; - - var valueSerializer = operandTranslation.Serializer; - if (valueSerializer is INullableSerializer nullableSerializer) - { - valueSerializer = nullableSerializer.ValueSerializer; - } + var valueSerializer = sourceTranslation.Serializer; IBsonSerializer targetSerializer; if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) @@ -275,21 +271,10 @@ private static AggregationExpression TranslateConvertUnderlyingTypeToEnum(UnaryE } else { - var enumType = targetType; - if (targetType.IsNullable(out var wrappedType)) - { - enumType = wrappedType; - } - - targetSerializer = EnumSerializer.Create(enumType); - } - - if (targetType.IsNullableEnum()) - { - targetSerializer = NullableSerializer.Create(targetSerializer); + targetSerializer = EnumSerializer.Create(targetType); } - return new AggregationExpression(expression, operandTranslation.Ast, targetSerializer); + return new AggregationExpression(expression, sourceTranslation.Ast, targetSerializer); } } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4048Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4048Tests.cs index 6cb2e712fb9..402a780bc0d 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4048Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4048Tests.cs @@ -237,7 +237,7 @@ public void IGrouping_Aggregate_with_seed_and_func_and_resultSelector_of_root_sh var expectedStages = new[] { "{ $group : { _id : '$_id', _elements : { $push: '$$ROOT' } } }", - "{ $project : { _id : '$_id', Result : { $let : { vars : { a : { $reduce : { input : '$_elements', initialValue : 0, in : { $add : ['$$value', '$$this.X'] } } } }, in : '$$a' } } } }", + "{ $project : { _id : '$_id', Result : { $reduce : { input : '$_elements', initialValue : 0, in : { $add : ['$$value', '$$this.X'] } } } } }", "{ $sort : { _id : 1 } }" }; AssertStages(stages, expectedStages); @@ -262,7 +262,7 @@ public void IGrouping_Aggregate_with_seed_and_func_and_resultSelector_of_scalar_ var expectedStages = new[] { "{ $group : { _id : '$_id', _elements : { $push: '$X' } } }", - "{ $project : { _id : '$_id', Result : { $let : { vars : { a : { $reduce : { input : '$_elements', initialValue : 0, in : { $add : ['$$value', '$$this'] } } } }, in : '$$a' } } } }", + "{ $project : { _id : '$_id', Result : { $reduce : { input : '$_elements', initialValue : 0, in : { $add : ['$$value', '$$this'] } } } } }", "{ $sort : { _id : 1 } }" }; AssertStages(stages, expectedStages); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5180Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5180Tests.cs new file mode 100644 index 00000000000..69bfcee131f --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5180Tests.cs @@ -0,0 +1,1596 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Linq; +using MongoDB.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5180Tests : Linq3IntegrationTest + { + [Theory] + [ParameterAttributeData] + public void Cast_Decimal_to_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal)x.Decimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { Decimal : '$Decimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$Decimal', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Decimal_to_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double)x.Decimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Decimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$Decimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Decimal_to_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int)x.Decimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Decimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$Decimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Decimal_to_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long)x.Decimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Decimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$Decimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Decimal_to_nullable_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal?)x.Decimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Decimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$Decimal', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Decimal_to_nullable_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double?)x.Decimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Decimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$Decimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Decimal_to_nullable_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int?)x.Decimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Decimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$Decimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Decimal_to_nullable_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long?)x.Decimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Decimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$Decimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Double_to_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal)x.Double); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Double', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$Double' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Double_to_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double)x.Double); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { Double : '$Double', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$Double', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Double_to_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int)x.Double); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Double', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$Double' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Double_to_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long)x.Double); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Double', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$Double' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Double_to_nullable_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal?)x.Double); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Double', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$Double' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Double_to_nullable_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double?)x.Double); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Double', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$Double', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Double_to_nullable_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int?)x.Double); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Double', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$Double' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Double_to_nullable_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long?)x.Double); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Double', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$Double' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Int_to_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal)x.Int); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Int', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$Int' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Int_to_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double)x.Int); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Int', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$Int' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Int_to_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int)x.Int); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { Int : '$Int', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$Int', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Int_to_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long)x.Int); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Int', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$Int' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Int_to_nullable_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal?)x.Int); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Int', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$Int' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Int_to_nullable_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double?)x.Int); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Int', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$Int' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Int_to_nullable_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int?)x.Int); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Int', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$Int', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Int_to_nullable_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long?)x.Int); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Int', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$Int' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Long_to_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal)x.Long); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Long', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$Long' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Long_to_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double)x.Long); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Long', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$Long' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Long_to_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int)x.Long); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$Long', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$Long' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Long_to_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long)x.Long); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { Long: '$Long', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$Long', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Long_to_nullable_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal?)x.Long); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Long', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$Long' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Long_to_nullable_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double?)x.Long); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Long', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$Long' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Long_to_nullable_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int?)x.Long); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Long', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$Long' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_Long_to_nullable_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long?)x.Long); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$Long', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$Long', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1L); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDecimal_to_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal)x.NullableDecimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { NullableDecimal : '$NullableDecimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$NullableDecimal', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDecimal_to_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double)x.NullableDecimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableDecimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$NullableDecimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDecimal_to_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int)x.NullableDecimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableDecimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$NullableDecimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDecimal_to_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long)x.NullableDecimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableDecimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$NullableDecimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDecimal_to_nullable_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal?)x.NullableDecimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { NullableDecimal: '$NullableDecimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$NullableDecimal', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDecimal_to_nullable_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double?)x.NullableDecimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableDecimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$NullableDecimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDecimal_to_nullable_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int?)x.NullableDecimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableDecimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$NullableDecimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDecimal_to_nullable_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long?)x.NullableDecimal); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableDecimal', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$NullableDecimal' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDouble_to_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal)x.NullableDouble); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableDouble', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$NullableDouble' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDouble_to_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double)x.NullableDouble); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { NullableDouble : '$NullableDouble', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$NullableDouble', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDouble_to_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int)x.NullableDouble); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableDouble', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$NullableDouble' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDouble_to_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long)x.NullableDouble); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableDouble', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$NullableDouble' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDouble_to_nullable_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal?)x.NullableDouble); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableDouble', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$NullableDouble' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDouble_to_nullable_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double?)x.NullableDouble); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { NullableDouble: '$NullableDouble', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$NullableDouble', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDouble_to_nullable_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int?)x.NullableDouble); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableDouble', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$NullableDouble' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableDouble_to_nullable_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long?)x.NullableDouble); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableDouble', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$NullableDouble' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableInt_to_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal)x.NullableInt); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableInt', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$NullableInt' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableInt_to_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double)x.NullableInt); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableInt', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$NullableInt' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableInt_to_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int)x.NullableInt); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { NullableInt : '$NullableInt', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$NullableInt', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableInt_to_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long)x.NullableInt); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableInt', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$NullableInt' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableInt_to_nullable_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal?)x.NullableInt); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableInt', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$NullableInt' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableInt_to_nullable_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double?)x.NullableInt); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableInt', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$NullableInt' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableInt_to_nullable_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int?)x.NullableInt); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { NullableInt: '$NullableInt', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$NullableInt', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableInt_to_nullable_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long?)x.NullableInt); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableInt', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toLong : '$NullableInt' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableLong_to_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal)x.NullableLong); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableLong', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$NullableLong' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableLong_to_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double)x.NullableLong); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableLong', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$NullableLong' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableLong_to_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int)x.NullableLong); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0 : '$NullableLong', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$NullableLong' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableLong_to_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long)x.NullableLong); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { NullableLong: '$NullableLong', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$NullableLong', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableLong_to_nullable_decimal_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (decimal?)x.NullableLong); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableLong', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDecimal : '$NullableLong' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0M); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableLong_to_nullable_double_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (double?)x.NullableLong); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableLong', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toDouble : '$NullableLong' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1.0); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableLong_to_nullable_int_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (int?)x.NullableLong); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { __fld0: '$NullableLong', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : { $toInt : '$NullableLong' }, _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + [Theory] + [ParameterAttributeData] + public void Cast_NullableLong_to_nullable_long_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => (long?)x.NullableLong); + + var stages = Translate(collection, queryable); + if (linqProvider == LinqProvider.V2) + { + AssertStages(stages, "{ $project : { NullableLong: '$NullableLong', _id : 0 } }"); + } + else + { + AssertStages(stages, "{ $project : { _v : '$NullableLong', _id : 0 } }"); + } + + var result = queryable.First(); + result.Should().Be(1); + } + + private IMongoCollection GetCollection(LinqProvider linqProvider) + { + var collection = GetCollection("test", linqProvider); + CreateCollection( + collection, + new C + { + Id = 1, + Decimal = 1.0M, + Double = 1.0, + Int = 1, + Long = 1L, + NullableDecimal = 1.0M, + NullableDouble = 1.0, + NullableInt = 1, + NullableLong = 1L + }); + return collection; + } + + private class C + { + public int Id { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal Decimal { get; set; } + public double Double { get; set; } + public int Int { get; set; } + public long Long { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal? NullableDecimal { get; set; } + public double? NullableDouble { get; set; } + public int? NullableInt { get; set; } + public long? NullableLong { get; set; } + } + } +}