Skip to content

Commit 6b00c3b

Browse files
committed
Replace parallel condition/result vectors with single CaseWhen vector in Expr::Case
The primary motivation for this change is to fix the visitor traversal order for CASE expressions. In SQL, CASE expressions follow a specific syntactic order (e.g., `CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5`), AST visitors now process nodes in the same order as they appear in the source code. The previous implementation, using separate `conditions` and `results` vectors, would visit all conditions first and then all results, which didn't match the source order. The new `CaseWhen` structure ensures visitors process expressions in the correct order: `a,1,2,3,4,5`. A secondary benefit is making invalid states unrepresentable in the type system. The previous implementation using parallel vectors (`conditions` and `results`) made it possible to create invalid CASE expressions where the number of conditions didn't match the number of results. When this happened, the `Display` implementation would silently drop elements from the longer list, potentially masking bugs. The new `CaseWhen` struct couples each condition with its result, making it impossible to create such mismatched states. While this is a breaking change to the AST structure, sqlparser has a history of making such changes when they improve correctness. I don't expect significant downstream breakages, and the benefits of correct visitor ordering and type safety are significant, so I think the trade-off is worthwhile.
1 parent 97f0be6 commit 6b00c3b

File tree

5 files changed

+160
-50
lines changed

5 files changed

+160
-50
lines changed

src/ast/mod.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,22 @@ pub enum CeilFloorKind {
599599
Scale(Value),
600600
}
601601

602+
/// A WHEN clause in a CASE expression containing both
603+
/// the condition and its corresponding result
604+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
605+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
606+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
607+
pub struct CaseWhen {
608+
pub condition: Expr,
609+
pub result: Expr,
610+
}
611+
612+
impl fmt::Display for CaseWhen {
613+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
614+
write!(f, "WHEN {} THEN {}", self.condition, self.result)
615+
}
616+
}
617+
602618
/// An SQL expression of any type.
603619
///
604620
/// # Semantics / Type Checking
@@ -917,8 +933,7 @@ pub enum Expr {
917933
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
918934
Case {
919935
operand: Option<Box<Expr>>,
920-
conditions: Vec<Expr>,
921-
results: Vec<Expr>,
936+
conditions: Vec<CaseWhen>,
922937
else_result: Option<Box<Expr>>,
923938
},
924939
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
@@ -1612,17 +1627,15 @@ impl fmt::Display for Expr {
16121627
Expr::Case {
16131628
operand,
16141629
conditions,
1615-
results,
16161630
else_result,
16171631
} => {
16181632
write!(f, "CASE")?;
16191633
if let Some(operand) = operand {
16201634
write!(f, " {operand}")?;
16211635
}
1622-
for (c, r) in conditions.iter().zip(results) {
1623-
write!(f, " WHEN {c} THEN {r}")?;
1636+
for when in conditions {
1637+
write!(f, " {when}")?;
16241638
}
1625-
16261639
if let Some(else_result) = else_result {
16271640
write!(f, " ELSE {else_result}")?;
16281641
}

src/ast/spans.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,15 +1444,15 @@ impl Spanned for Expr {
14441444
Expr::Case {
14451445
operand,
14461446
conditions,
1447-
results,
14481447
else_result,
14491448
} => union_spans(
14501449
operand
14511450
.as_ref()
14521451
.map(|i| i.span())
14531452
.into_iter()
1454-
.chain(conditions.iter().map(|i| i.span()))
1455-
.chain(results.iter().map(|i| i.span()))
1453+
.chain(conditions.iter().flat_map(|case_when| {
1454+
[case_when.condition.span(), case_when.result.span()]
1455+
}))
14561456
.chain(else_result.as_ref().map(|i| i.span())),
14571457
),
14581458
Expr::Exists { subquery, .. } => subquery.span(),

src/parser/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,11 +2020,11 @@ impl<'a> Parser<'a> {
20202020
self.expect_keyword_is(Keyword::WHEN)?;
20212021
}
20222022
let mut conditions = vec![];
2023-
let mut results = vec![];
20242023
loop {
2025-
conditions.push(self.parse_expr()?);
2024+
let condition = self.parse_expr()?;
20262025
self.expect_keyword_is(Keyword::THEN)?;
2027-
results.push(self.parse_expr()?);
2026+
let result = self.parse_expr()?;
2027+
conditions.push(CaseWhen { condition, result });
20282028
if !self.parse_keyword(Keyword::WHEN) {
20292029
break;
20302030
}
@@ -2038,7 +2038,6 @@ impl<'a> Parser<'a> {
20382038
Ok(Expr::Case {
20392039
operand,
20402040
conditions,
2041-
results,
20422041
else_result,
20432042
})
20442043
}

tests/sqlparser_common.rs

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6539,22 +6539,26 @@ fn parse_searched_case_expr() {
65396539
&Case {
65406540
operand: None,
65416541
conditions: vec![
6542-
IsNull(Box::new(Identifier(Ident::new("bar")))),
6543-
BinaryOp {
6544-
left: Box::new(Identifier(Ident::new("bar"))),
6545-
op: Eq,
6546-
right: Box::new(Expr::Value(number("0"))),
6542+
CaseWhen {
6543+
condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
6544+
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
65476545
},
6548-
BinaryOp {
6549-
left: Box::new(Identifier(Ident::new("bar"))),
6550-
op: GtEq,
6551-
right: Box::new(Expr::Value(number("0"))),
6546+
CaseWhen {
6547+
condition: BinaryOp {
6548+
left: Box::new(Identifier(Ident::new("bar"))),
6549+
op: Eq,
6550+
right: Box::new(Expr::Value(number("0"))),
6551+
},
6552+
result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
6553+
},
6554+
CaseWhen {
6555+
condition: BinaryOp {
6556+
left: Box::new(Identifier(Ident::new("bar"))),
6557+
op: GtEq,
6558+
right: Box::new(Expr::Value(number("0"))),
6559+
},
6560+
result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
65526561
},
6553-
],
6554-
results: vec![
6555-
Expr::Value(Value::SingleQuotedString("null".to_string())),
6556-
Expr::Value(Value::SingleQuotedString("=0".to_string())),
6557-
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
65586562
],
65596563
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
65606564
"<0".to_string()
@@ -6573,8 +6577,10 @@ fn parse_simple_case_expr() {
65736577
assert_eq!(
65746578
&Case {
65756579
operand: Some(Box::new(Identifier(Ident::new("foo")))),
6576-
conditions: vec![Expr::Value(number("1"))],
6577-
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))],
6580+
conditions: vec![CaseWhen {
6581+
condition: Expr::Value(number("1")),
6582+
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
6583+
}],
65786584
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
65796585
"N".to_string()
65806586
)))),
@@ -13734,6 +13740,31 @@ fn test_trailing_commas_in_from() {
1373413740
);
1373513741
}
1373613742

13743+
#[test]
13744+
#[cfg(feature = "visitor")]
13745+
fn test_visit_order() {
13746+
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
13747+
let stmt = verified_stmt(sql);
13748+
let mut visited = vec![];
13749+
sqlparser::ast::visit_expressions(&stmt, |expr| {
13750+
visited.push(expr.to_string());
13751+
core::ops::ControlFlow::<()>::Continue(())
13752+
});
13753+
13754+
assert_eq!(
13755+
visited,
13756+
[
13757+
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
13758+
"a",
13759+
"1",
13760+
"2",
13761+
"3",
13762+
"4",
13763+
"5"
13764+
]
13765+
);
13766+
}
13767+
1373713768
#[test]
1373813769
fn test_lambdas() {
1373913770
let dialects = all_dialects_where(|d| d.supports_lambda_functions());
@@ -13761,28 +13792,30 @@ fn test_lambdas() {
1376113792
body: Box::new(Expr::Case {
1376213793
operand: None,
1376313794
conditions: vec![
13764-
Expr::BinaryOp {
13765-
left: Box::new(Expr::Identifier(Ident::new("p1"))),
13766-
op: BinaryOperator::Eq,
13767-
right: Box::new(Expr::Identifier(Ident::new("p2")))
13795+
CaseWhen {
13796+
condition: Expr::BinaryOp {
13797+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
13798+
op: BinaryOperator::Eq,
13799+
right: Box::new(Expr::Identifier(Ident::new("p2")))
13800+
},
13801+
result: Expr::Value(number("0"))
1376813802
},
13769-
Expr::BinaryOp {
13770-
left: Box::new(call(
13771-
"reverse",
13772-
[Expr::Identifier(Ident::new("p1"))]
13773-
)),
13774-
op: BinaryOperator::Lt,
13775-
right: Box::new(call(
13776-
"reverse",
13777-
[Expr::Identifier(Ident::new("p2"))]
13778-
))
13779-
}
13780-
],
13781-
results: vec![
13782-
Expr::Value(number("0")),
13783-
Expr::UnaryOp {
13784-
op: UnaryOperator::Minus,
13785-
expr: Box::new(Expr::Value(number("1")))
13803+
CaseWhen {
13804+
condition: Expr::BinaryOp {
13805+
left: Box::new(call(
13806+
"reverse",
13807+
[Expr::Identifier(Ident::new("p1"))]
13808+
)),
13809+
op: BinaryOperator::Lt,
13810+
right: Box::new(call(
13811+
"reverse",
13812+
[Expr::Identifier(Ident::new("p2"))]
13813+
))
13814+
},
13815+
result: Expr::UnaryOp {
13816+
op: UnaryOperator::Minus,
13817+
expr: Box::new(Expr::Value(number("1")))
13818+
}
1378613819
}
1378713820
],
1378813821
else_result: Some(Box::new(Expr::Value(number("1"))))

tests/sqlparser_databricks.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,71 @@ fn test_databricks_exists() {
8383
);
8484
}
8585

86+
#[test]
87+
fn test_databricks_lambdas() {
88+
#[rustfmt::skip]
89+
let sql = concat!(
90+
"SELECT array_sort(array('Hello', 'World'), ",
91+
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
92+
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
93+
"ELSE 1 END)",
94+
);
95+
pretty_assertions::assert_eq!(
96+
SelectItem::UnnamedExpr(call(
97+
"array_sort",
98+
[
99+
call(
100+
"array",
101+
[
102+
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
103+
Expr::Value(Value::SingleQuotedString("World".to_owned()))
104+
]
105+
),
106+
Expr::Lambda(LambdaFunction {
107+
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
108+
body: Box::new(Expr::Case {
109+
operand: None,
110+
conditions: vec![
111+
CaseWhen {
112+
condition: Expr::BinaryOp {
113+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
114+
op: BinaryOperator::Eq,
115+
right: Box::new(Expr::Identifier(Ident::new("p2")))
116+
},
117+
result: Expr::Value(number("0"))
118+
},
119+
CaseWhen {
120+
condition: Expr::BinaryOp {
121+
left: Box::new(call(
122+
"reverse",
123+
[Expr::Identifier(Ident::new("p1"))]
124+
)),
125+
op: BinaryOperator::Lt,
126+
right: Box::new(call(
127+
"reverse",
128+
[Expr::Identifier(Ident::new("p2"))]
129+
)),
130+
},
131+
result: Expr::UnaryOp {
132+
op: UnaryOperator::Minus,
133+
expr: Box::new(Expr::Value(number("1")))
134+
}
135+
},
136+
],
137+
else_result: Some(Box::new(Expr::Value(number("1"))))
138+
})
139+
})
140+
]
141+
)),
142+
databricks().verified_only_select(sql).projection[0]
143+
);
144+
145+
databricks().verified_expr(
146+
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
147+
);
148+
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
149+
}
150+
86151
#[test]
87152
fn test_values_clause() {
88153
let values = Values {

0 commit comments

Comments
 (0)