Skip to content

Commit 04732cf

Browse files
committed
Replace parallel condition/result vectors with single CaseWhen vector in Expr::Case
The primary motivation for this change is to fix the visitor traversal order for CASE expressions. In SQL, CASE expressions follow a specific syntactic order (e.g., `CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END`), and with this change AST visitors process nodes in the same order as they appear in the source code. The previous implementation, using separate `conditions` and `results` vectors, would visit all conditions first and then all results, which didn't match the source order. The new `CaseWhen` structure ensures visitors process expressions in the correct order: `a, 1, 2, 3, 4, 5`. A secondary benefit is making invalid states unrepresentable in the type system. The previous implementation using parallel vectors (`conditions` and `results`) made it possible to create invalid CASE expressions where the number of conditions didn't match the number of results. When this happened, the `Display` implementation would silently drop elements from the longer list, potentially masking bugs. The new `CaseWhen` struct couples each condition with its result, making it impossible to create such mismatched states. While this is a breaking change to the AST structure, sqlparser has a history of making such changes when they improve correctness. I don't expect significant downstream breakage, and the benefits of correct visitor ordering and type safety are substantial, so I think the trade-off is worthwhile.
1 parent 3ace97c commit 04732cf

File tree

5 files changed

+160
-50
lines changed

5 files changed

+160
-50
lines changed

src/ast/mod.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,22 @@ pub enum CeilFloorKind {
600600
Scale(Value),
601601
}
602602

603+
/// A WHEN clause in a CASE expression containing both
604+
/// the condition and its corresponding result
605+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
606+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
607+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
608+
pub struct CaseWhen {
609+
pub condition: Expr,
610+
pub result: Expr,
611+
}
612+
613+
impl fmt::Display for CaseWhen {
614+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
615+
write!(f, "WHEN {} THEN {}", self.condition, self.result)
616+
}
617+
}
618+
603619
/// An SQL expression of any type.
604620
///
605621
/// # Semantics / Type Checking
@@ -918,8 +934,7 @@ pub enum Expr {
918934
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
919935
Case {
920936
operand: Option<Box<Expr>>,
921-
conditions: Vec<Expr>,
922-
results: Vec<Expr>,
937+
conditions: Vec<CaseWhen>,
923938
else_result: Option<Box<Expr>>,
924939
},
925940
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
@@ -1621,17 +1636,15 @@ impl fmt::Display for Expr {
16211636
Expr::Case {
16221637
operand,
16231638
conditions,
1624-
results,
16251639
else_result,
16261640
} => {
16271641
write!(f, "CASE")?;
16281642
if let Some(operand) = operand {
16291643
write!(f, " {operand}")?;
16301644
}
1631-
for (c, r) in conditions.iter().zip(results) {
1632-
write!(f, " WHEN {c} THEN {r}")?;
1645+
for when in conditions {
1646+
write!(f, " {when}")?;
16331647
}
1634-
16351648
if let Some(else_result) = else_result {
16361649
write!(f, " ELSE {else_result}")?;
16371650
}

src/ast/spans.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,15 +1445,15 @@ impl Spanned for Expr {
14451445
Expr::Case {
14461446
operand,
14471447
conditions,
1448-
results,
14491448
else_result,
14501449
} => union_spans(
14511450
operand
14521451
.as_ref()
14531452
.map(|i| i.span())
14541453
.into_iter()
1455-
.chain(conditions.iter().map(|i| i.span()))
1456-
.chain(results.iter().map(|i| i.span()))
1454+
.chain(conditions.iter().flat_map(|case_when| {
1455+
[case_when.condition.span(), case_when.result.span()]
1456+
}))
14571457
.chain(else_result.as_ref().map(|i| i.span())),
14581458
),
14591459
Expr::Exists { subquery, .. } => subquery.span(),

src/parser/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,11 +2065,11 @@ impl<'a> Parser<'a> {
20652065
self.expect_keyword_is(Keyword::WHEN)?;
20662066
}
20672067
let mut conditions = vec![];
2068-
let mut results = vec![];
20692068
loop {
2070-
conditions.push(self.parse_expr()?);
2069+
let condition = self.parse_expr()?;
20712070
self.expect_keyword_is(Keyword::THEN)?;
2072-
results.push(self.parse_expr()?);
2071+
let result = self.parse_expr()?;
2072+
conditions.push(CaseWhen { condition, result });
20732073
if !self.parse_keyword(Keyword::WHEN) {
20742074
break;
20752075
}
@@ -2083,7 +2083,6 @@ impl<'a> Parser<'a> {
20832083
Ok(Expr::Case {
20842084
operand,
20852085
conditions,
2086-
results,
20872086
else_result,
20882087
})
20892088
}

tests/sqlparser_common.rs

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6539,22 +6539,26 @@ fn parse_searched_case_expr() {
65396539
&Case {
65406540
operand: None,
65416541
conditions: vec![
6542-
IsNull(Box::new(Identifier(Ident::new("bar")))),
6543-
BinaryOp {
6544-
left: Box::new(Identifier(Ident::new("bar"))),
6545-
op: Eq,
6546-
right: Box::new(Expr::Value(number("0"))),
6542+
CaseWhen {
6543+
condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
6544+
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
65476545
},
6548-
BinaryOp {
6549-
left: Box::new(Identifier(Ident::new("bar"))),
6550-
op: GtEq,
6551-
right: Box::new(Expr::Value(number("0"))),
6546+
CaseWhen {
6547+
condition: BinaryOp {
6548+
left: Box::new(Identifier(Ident::new("bar"))),
6549+
op: Eq,
6550+
right: Box::new(Expr::Value(number("0"))),
6551+
},
6552+
result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
6553+
},
6554+
CaseWhen {
6555+
condition: BinaryOp {
6556+
left: Box::new(Identifier(Ident::new("bar"))),
6557+
op: GtEq,
6558+
right: Box::new(Expr::Value(number("0"))),
6559+
},
6560+
result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
65526561
},
6553-
],
6554-
results: vec![
6555-
Expr::Value(Value::SingleQuotedString("null".to_string())),
6556-
Expr::Value(Value::SingleQuotedString("=0".to_string())),
6557-
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
65586562
],
65596563
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
65606564
"<0".to_string()
@@ -6573,8 +6577,10 @@ fn parse_simple_case_expr() {
65736577
assert_eq!(
65746578
&Case {
65756579
operand: Some(Box::new(Identifier(Ident::new("foo")))),
6576-
conditions: vec![Expr::Value(number("1"))],
6577-
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))],
6580+
conditions: vec![CaseWhen {
6581+
condition: Expr::Value(number("1")),
6582+
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
6583+
}],
65786584
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
65796585
"N".to_string()
65806586
)))),
@@ -13734,6 +13740,31 @@ fn test_trailing_commas_in_from() {
1373413740
);
1373513741
}
1373613742

13743+
#[test]
13744+
#[cfg(feature = "visitor")]
13745+
fn test_visit_order() {
13746+
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
13747+
let stmt = verified_stmt(sql);
13748+
let mut visited = vec![];
13749+
sqlparser::ast::visit_expressions(&stmt, |expr| {
13750+
visited.push(expr.to_string());
13751+
core::ops::ControlFlow::<()>::Continue(())
13752+
});
13753+
13754+
assert_eq!(
13755+
visited,
13756+
[
13757+
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
13758+
"a",
13759+
"1",
13760+
"2",
13761+
"3",
13762+
"4",
13763+
"5"
13764+
]
13765+
);
13766+
}
13767+
1373713768
#[test]
1373813769
fn test_lambdas() {
1373913770
let dialects = all_dialects_where(|d| d.supports_lambda_functions());
@@ -13761,28 +13792,30 @@ fn test_lambdas() {
1376113792
body: Box::new(Expr::Case {
1376213793
operand: None,
1376313794
conditions: vec![
13764-
Expr::BinaryOp {
13765-
left: Box::new(Expr::Identifier(Ident::new("p1"))),
13766-
op: BinaryOperator::Eq,
13767-
right: Box::new(Expr::Identifier(Ident::new("p2")))
13795+
CaseWhen {
13796+
condition: Expr::BinaryOp {
13797+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
13798+
op: BinaryOperator::Eq,
13799+
right: Box::new(Expr::Identifier(Ident::new("p2")))
13800+
},
13801+
result: Expr::Value(number("0"))
1376813802
},
13769-
Expr::BinaryOp {
13770-
left: Box::new(call(
13771-
"reverse",
13772-
[Expr::Identifier(Ident::new("p1"))]
13773-
)),
13774-
op: BinaryOperator::Lt,
13775-
right: Box::new(call(
13776-
"reverse",
13777-
[Expr::Identifier(Ident::new("p2"))]
13778-
))
13779-
}
13780-
],
13781-
results: vec![
13782-
Expr::Value(number("0")),
13783-
Expr::UnaryOp {
13784-
op: UnaryOperator::Minus,
13785-
expr: Box::new(Expr::Value(number("1")))
13803+
CaseWhen {
13804+
condition: Expr::BinaryOp {
13805+
left: Box::new(call(
13806+
"reverse",
13807+
[Expr::Identifier(Ident::new("p1"))]
13808+
)),
13809+
op: BinaryOperator::Lt,
13810+
right: Box::new(call(
13811+
"reverse",
13812+
[Expr::Identifier(Ident::new("p2"))]
13813+
))
13814+
},
13815+
result: Expr::UnaryOp {
13816+
op: UnaryOperator::Minus,
13817+
expr: Box::new(Expr::Value(number("1")))
13818+
}
1378613819
}
1378713820
],
1378813821
else_result: Some(Box::new(Expr::Value(number("1"))))

tests/sqlparser_databricks.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,71 @@ fn test_databricks_exists() {
8383
);
8484
}
8585

86+
#[test]
87+
fn test_databricks_lambdas() {
88+
#[rustfmt::skip]
89+
let sql = concat!(
90+
"SELECT array_sort(array('Hello', 'World'), ",
91+
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
92+
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
93+
"ELSE 1 END)",
94+
);
95+
pretty_assertions::assert_eq!(
96+
SelectItem::UnnamedExpr(call(
97+
"array_sort",
98+
[
99+
call(
100+
"array",
101+
[
102+
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
103+
Expr::Value(Value::SingleQuotedString("World".to_owned()))
104+
]
105+
),
106+
Expr::Lambda(LambdaFunction {
107+
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
108+
body: Box::new(Expr::Case {
109+
operand: None,
110+
conditions: vec![
111+
CaseWhen {
112+
condition: Expr::BinaryOp {
113+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
114+
op: BinaryOperator::Eq,
115+
right: Box::new(Expr::Identifier(Ident::new("p2")))
116+
},
117+
result: Expr::Value(number("0"))
118+
},
119+
CaseWhen {
120+
condition: Expr::BinaryOp {
121+
left: Box::new(call(
122+
"reverse",
123+
[Expr::Identifier(Ident::new("p1"))]
124+
)),
125+
op: BinaryOperator::Lt,
126+
right: Box::new(call(
127+
"reverse",
128+
[Expr::Identifier(Ident::new("p2"))]
129+
)),
130+
},
131+
result: Expr::UnaryOp {
132+
op: UnaryOperator::Minus,
133+
expr: Box::new(Expr::Value(number("1")))
134+
}
135+
},
136+
],
137+
else_result: Some(Box::new(Expr::Value(number("1"))))
138+
})
139+
})
140+
]
141+
)),
142+
databricks().verified_only_select(sql).projection[0]
143+
);
144+
145+
databricks().verified_expr(
146+
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
147+
);
148+
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
149+
}
150+
86151
#[test]
87152
fn test_values_clause() {
88153
let values = Values {

0 commit comments

Comments
 (0)