Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,57 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition

}

public override SyntaxNode? VisitSwitchExpression(SwitchExpressionSyntax node)
{
// Reverse arms order to start from the default value
var arms = node.Arms.Reverse();

ExpressionSyntax? currentExpression = null;

foreach (var arm in arms)
{
var armExpression = (ExpressionSyntax)Visit(arm.Expression);

// Handle fallback value
if (currentExpression == null)
{
currentExpression = arm.Pattern is DiscardPatternSyntax
? armExpression
: SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);

continue;
}

// Handle each arm, only if it's a constant expression
if (arm.Pattern is ConstantPatternSyntax constant)
{
ExpressionSyntax expression = SyntaxFactory.BinaryExpression(SyntaxKind.EqualsExpression, (ExpressionSyntax)Visit(node.GoverningExpression), constant.Expression);

// Add the when clause as a AND expression
if (arm.WhenClause != null)
{
expression = SyntaxFactory.BinaryExpression(
SyntaxKind.LogicalAndExpression,
expression,
(ExpressionSyntax)Visit(arm.WhenClause.Condition)
);
}

currentExpression = SyntaxFactory.ConditionalExpression(
expression,
armExpression,
currentExpression
);

continue;
}

throw new InvalidOperationException("Switch expressions rewriting is only supported with constant values");
}

return currentExpression;
}

public override SyntaxNode? VisitMemberBindingExpression(MemberBindingExpressionSyntax node)
{
if (_conditionalAccessExpressionsStack.Count > 0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// <auto-generated/>
#nullable disable
using EntityFrameworkCore.Projectables;

namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class _Foo_SomeNumber
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo, int, int>> Expression()
{
return (global::Foo @this, int input) => input == 1 ? 2 : input == 3 ? 4 : input == 4 && @this.FancyNumber == 12 ? 48 : 1000;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,33 @@ class Foo {
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}

[Fact]
public Task SwitchExpression()
{
var compilation = CreateCompilation(@"
using EntityFrameworkCore.Projectables;

class Foo {
public int? FancyNumber { get; set; }

[Projectable(NullConditionalRewriteSupport = NullConditionalRewriteSupport.Rewrite)]
public int SomeNumber(int input) => input switch {
1 => 2,
3 => 4,
4 when FancyNumber == 12 => 48,
_ => 1000,
};
}
");

var result = RunGenerator(compilation);

Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);

return Verifier.Verify(result.GeneratedTrees[0].ToString());
}

[Fact]
public Task GenericTypes()
{
Expand Down