Skip to content

Commit b011ee2

Browse files
authored
feat(bigquery): Add support for side & kind on set operators (#4959)
* feat(bigquery): Support set operations * PR Feedback 1
1 parent 44b955b commit b011ee2

File tree

7 files changed

+228
-50
lines changed

7 files changed

+228
-50
lines changed

sqlglot/expressions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3370,6 +3370,9 @@ class SetOperation(Query):
33703370
"expression": True,
33713371
"distinct": False,
33723372
"by_name": False,
3373+
"side": False,
3374+
"kind": False,
3375+
"on": False,
33733376
**QUERY_MODIFIERS,
33743377
}
33753378

@@ -3408,6 +3411,14 @@ def left(self) -> Query:
34083411
def right(self) -> Query:
34093412
return self.expression
34103413

3414+
@property
3415+
def kind(self) -> str:
3416+
return self.text("kind").upper()
3417+
3418+
@property
3419+
def side(self) -> str:
3420+
return self.text("side").upper()
3421+
34113422

34123423
class Union(SetOperation):
34133424
pass

sqlglot/generator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,12 +1445,18 @@ def set_operation(self, expression: exp.SetOperation) -> str:
14451445
self.unsupported(f"{op_name} requires DISTINCT or ALL to be specified")
14461446

14471447
if distinct is default_distinct:
1448-
kind = ""
1448+
distinct_or_all = ""
14491449
else:
1450-
kind = " DISTINCT" if distinct else " ALL"
1450+
distinct_or_all = " DISTINCT" if distinct else " ALL"
1451+
1452+
side_kind = " ".join(filter(None, [expression.side, expression.kind]))
1453+
side_kind = f"{side_kind} " if side_kind else ""
14511454

14521455
by_name = " BY NAME" if expression.args.get("by_name") else ""
1453-
return f"{op_name}{kind}{by_name}"
1456+
on = self.expressions(expression, key="on", flat=True)
1457+
on = f" ON ({on})" if on else ""
1458+
1459+
return f"{side_kind}{op_name}{distinct_or_all}{by_name}{on}"
14541460

14551461
def set_operations(self, expression: exp.SetOperation) -> str:
14561462
if not self.SET_OP_MODIFIERS:

sqlglot/optimizer/pushdown_projections.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,31 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
4949
parent_selections = {SELECT_ALL}
5050

5151
if isinstance(scope.expression, exp.SetOperation):
52-
left, right = scope.union_scopes
53-
if len(left.expression.selects) != len(right.expression.selects):
54-
scope_sql = scope.expression.sql()
55-
raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")
56-
57-
referenced_columns[left] = parent_selections
58-
59-
if any(select.is_star for select in right.expression.selects):
60-
referenced_columns[right] = parent_selections
61-
elif not any(select.is_star for select in left.expression.selects):
62-
if scope.expression.args.get("by_name"):
63-
referenced_columns[right] = referenced_columns[left]
64-
else:
65-
referenced_columns[right] = [
66-
right.expression.selects[i].alias_or_name
67-
for i, select in enumerate(left.expression.selects)
68-
if SELECT_ALL in parent_selections
69-
or select.alias_or_name in parent_selections
70-
]
52+
set_op = scope.expression
53+
if not (set_op.kind or set_op.side):
54+
# Only optimize this set operation when it is not using the BigQuery-specific
55+
# kind / side syntax (e.g. INNER UNION ALL BY NAME), which changes the semantics of the operation
56+
left, right = scope.union_scopes
57+
if len(left.expression.selects) != len(right.expression.selects):
58+
scope_sql = scope.expression.sql()
59+
raise OptimizeError(
60+
f"Invalid set operation due to column mismatch: {scope_sql}."
61+
)
62+
63+
referenced_columns[left] = parent_selections
64+
65+
if any(select.is_star for select in right.expression.selects):
66+
referenced_columns[right] = parent_selections
67+
elif not any(select.is_star for select in left.expression.selects):
68+
if scope.expression.args.get("by_name"):
69+
referenced_columns[right] = referenced_columns[left]
70+
else:
71+
referenced_columns[right] = [
72+
right.expression.selects[i].alias_or_name
73+
for i, select in enumerate(left.expression.selects)
74+
if SELECT_ALL in parent_selections
75+
or select.alias_or_name in parent_selections
76+
]
7177

7278
if isinstance(scope.expression, exp.Select):
7379
if remove_unused_selections:

sqlglot/optimizer/qualify_columns.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,32 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequenc
916916
if source.expression.is_type(exp.DataType.Type.STRUCT):
917917
for k in source.expression.type.expressions: # type: ignore
918918
columns.append(k.name)
919+
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
920+
set_op = source.expression
921+
922+
# BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
923+
on_column_list = set_op.args.get("on")
924+
925+
if on_column_list:
926+
# The resulting columns are the columns in the ON clause:
927+
# {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
928+
columns = [col.name for col in on_column_list]
929+
elif set_op.side or set_op.kind:
930+
side = set_op.side
931+
kind = set_op.kind
932+
933+
left = set_op.left.named_selects
934+
right = set_op.right.named_selects
935+
936+
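# We use dict.fromkeys to deduplicate keys and maintain insertion order (NOTE: the INNER branch intersects the key views, which yields a set, so order is not guaranteed there)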
937+
if side == "LEFT":
938+
columns = left
939+
elif side == "FULL":
940+
columns = list(dict.fromkeys(left + right))
941+
elif kind == "INNER":
942+
columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
943+
else:
944+
columns = set_op.named_selects
919945
else:
920946
columns = source.expression.named_selects
921947

sqlglot/parser.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4596,39 +4596,69 @@ def _parse_locks(self) -> t.List[exp.Lock]:
45964596

45974597
return locks
45984598

4599-
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
4600-
while this and self._match_set(self.SET_OPERATIONS):
4601-
token_type = self._prev.token_type
4599+
def parse_set_operation(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
4600+
start = self._index
4601+
_, side_token, kind_token = self._parse_join_parts()
46024602

4603-
if token_type == TokenType.UNION:
4604-
operation: t.Type[exp.SetOperation] = exp.Union
4605-
elif token_type == TokenType.EXCEPT:
4606-
operation = exp.Except
4607-
else:
4608-
operation = exp.Intersect
4603+
side = side_token.text if side_token else None
4604+
kind = kind_token.text if kind_token else None
46094605

4610-
comments = self._prev.comments
4606+
if not self._match_set(self.SET_OPERATIONS):
4607+
self._retreat(start)
4608+
return None
46114609

4612-
if self._match(TokenType.DISTINCT):
4613-
distinct: t.Optional[bool] = True
4614-
elif self._match(TokenType.ALL):
4615-
distinct = False
4616-
else:
4617-
distinct = self.dialect.SET_OP_DISTINCT_BY_DEFAULT[operation]
4618-
if distinct is None:
4619-
self.raise_error(f"Expected DISTINCT or ALL for {operation.__name__}")
4610+
token_type = self._prev.token_type
46204611

4621-
by_name = self._match_text_seq("BY", "NAME")
4622-
expression = self._parse_select(nested=True, parse_set_operation=False)
4612+
if token_type == TokenType.UNION:
4613+
operation: t.Type[exp.SetOperation] = exp.Union
4614+
elif token_type == TokenType.EXCEPT:
4615+
operation = exp.Except
4616+
else:
4617+
operation = exp.Intersect
46234618

4624-
this = self.expression(
4625-
operation,
4626-
comments=comments,
4627-
this=this,
4628-
distinct=distinct,
4629-
by_name=by_name,
4630-
expression=expression,
4631-
)
4619+
comments = self._prev.comments
4620+
4621+
if self._match(TokenType.DISTINCT):
4622+
distinct: t.Optional[bool] = True
4623+
elif self._match(TokenType.ALL):
4624+
distinct = False
4625+
else:
4626+
distinct = self.dialect.SET_OP_DISTINCT_BY_DEFAULT[operation]
4627+
if distinct is None:
4628+
self.raise_error(f"Expected DISTINCT or ALL for {operation.__name__}")
4629+
4630+
by_name = self._match_text_seq("BY", "NAME") or self._match_text_seq(
4631+
"STRICT", "CORRESPONDING"
4632+
)
4633+
if self._match_text_seq("CORRESPONDING"):
4634+
by_name = True
4635+
if not side and not kind:
4636+
kind = "INNER"
4637+
4638+
on_column_list = None
4639+
if by_name and self._match_texts(("ON", "BY")):
4640+
on_column_list = self._parse_wrapped_csv(self._parse_column)
4641+
4642+
expression = self._parse_select(nested=True, parse_set_operation=False)
4643+
4644+
return self.expression(
4645+
operation,
4646+
comments=comments,
4647+
this=this,
4648+
distinct=distinct,
4649+
by_name=by_name,
4650+
expression=expression,
4651+
side=side,
4652+
kind=kind,
4653+
on=on_column_list,
4654+
)
4655+
4656+
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
4657+
while True:
4658+
setop = self.parse_set_operation(this)
4659+
if not setop:
4660+
break
4661+
this = setop
46324662

46334663
if isinstance(this, exp.SetOperation) and self.MODIFIERS_ATTACHED_TO_SET_OP:
46344664
expression = this.expression

tests/dialects/test_bigquery.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,3 +2382,43 @@ def test_annotate_timestamps(self):
23822382

23832383
for select in annotated.selects:
23842384
self.assertEqual(select.type.sql("bigquery"), "TIMESTAMP")
2385+
2386+
def test_set_operations(self):
2387+
self.validate_identity("SELECT 1 AS foo INNER UNION ALL SELECT 3 AS foo, 4 AS bar")
2388+
2389+
for side in ("", " LEFT", " FULL"):
2390+
for kind in ("", " OUTER"):
2391+
for name in (
2392+
"",
2393+
" BY NAME",
2394+
" BY NAME ON (foo, bar)",
2395+
):
2396+
with self.subTest(f"Testing {side} {kind} {name} in test_set_operations"):
2397+
self.validate_identity(
2398+
f"SELECT 1 AS foo{side}{kind} UNION ALL{name} SELECT 3 AS foo, 4 AS bar",
2399+
)
2400+
2401+
self.validate_identity(
2402+
"SELECT 1 AS x UNION ALL CORRESPONDING SELECT 2 AS x",
2403+
"SELECT 1 AS x INNER UNION ALL BY NAME SELECT 2 AS x",
2404+
)
2405+
2406+
self.validate_identity(
2407+
"SELECT 1 AS x UNION ALL CORRESPONDING BY (foo, bar) SELECT 2 AS x",
2408+
"SELECT 1 AS x INNER UNION ALL BY NAME ON (foo, bar) SELECT 2 AS x",
2409+
)
2410+
2411+
self.validate_identity(
2412+
"SELECT 1 AS x LEFT UNION ALL CORRESPONDING SELECT 2 AS x",
2413+
"SELECT 1 AS x LEFT UNION ALL BY NAME SELECT 2 AS x",
2414+
)
2415+
2416+
self.validate_identity(
2417+
"SELECT 1 AS x UNION ALL STRICT CORRESPONDING SELECT 2 AS x",
2418+
"SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x",
2419+
)
2420+
2421+
self.validate_identity(
2422+
"SELECT 1 AS x UNION ALL STRICT CORRESPONDING BY (foo, bar) SELECT 2 AS x",
2423+
"SELECT 1 AS x UNION ALL BY NAME ON (foo, bar) SELECT 2 AS x",
2424+
)

tests/fixtures/optimizer/qualify_columns.sql

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,65 @@ SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM
325325
((select a from x where a < 1)) UNION ((select a from x where a > 2));
326326
((SELECT x.a AS a FROM x AS x WHERE x.a < 1)) UNION ((SELECT x.a AS a FROM x AS x WHERE x.a > 2));
327327

328+
329+
# dialect: bigquery
330+
# execute: false
331+
SELECT * FROM (SELECT 1 AS foo, 2 AS bar INNER UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz);
332+
SELECT _q_0.bar AS bar FROM (SELECT 1 AS foo, 2 AS bar INNER UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) AS _q_0;
333+
334+
# dialect: bigquery
335+
# execute: false
336+
SELECT * FROM (SELECT 1 AS foo, 2 AS bar UNION ALL CORRESPONDING SELECT 3 AS bar, 4 AS baz);
337+
SELECT _q_0.bar AS bar FROM (SELECT 1 AS foo, 2 AS bar INNER UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) AS _q_0;
338+
339+
# dialect: bigquery
340+
# execute: false
341+
SELECT * FROM (SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz);
342+
SELECT _q_0.foo AS foo, _q_0.bar AS bar FROM (SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) AS _q_0;
343+
344+
# dialect: bigquery
345+
# execute: false
346+
SELECT * FROM (SELECT 1 AS foo, 2 AS bar FULL UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz);
347+
SELECT _q_0.foo AS foo, _q_0.bar AS bar, _q_0.baz AS baz FROM (SELECT 1 AS foo, 2 AS bar FULL UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) AS _q_0;
348+
349+
350+
# dialect: bigquery
351+
# execute: false
352+
SELECT * FROM (SELECT 1 AS foo, 2 AS bar LEFT UNION ALL CORRESPONDING SELECT 3 AS bar, 4 AS baz);
353+
SELECT _q_0.foo AS foo, _q_0.bar AS bar FROM (SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) AS _q_0;
354+
355+
356+
# dialect: bigquery
357+
# execute: false
358+
SELECT * FROM (SELECT 1 AS foo, 2 AS bar FULL UNION ALL CORRESPONDING SELECT 3 AS bar, 4 AS baz);
359+
SELECT _q_0.foo AS foo, _q_0.bar AS bar, _q_0.baz AS baz FROM (SELECT 1 AS foo, 2 AS bar FULL UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) AS _q_0;
360+
361+
# dialect: bigquery
362+
# execute: false
363+
SELECT * FROM (SELECT 1 AS foo, 2 AS bar FULL UNION ALL CORRESPONDING BY (foo, bar) SELECT 3 AS bar, 4 AS baz);
364+
SELECT _q_0.foo AS foo, _q_0.bar AS bar FROM (SELECT 1 AS foo, 2 AS bar FULL UNION ALL BY NAME ON (foo, bar) SELECT 3 AS bar, 4 AS baz) AS _q_0;
365+
366+
367+
# dialect: bigquery
368+
# execute: false
369+
SELECT * FROM (SELECT 1 AS foo, 2 AS bar FULL UNION ALL BY NAME ON (foo, bar) SELECT 3 AS bar, 4 AS baz);
370+
SELECT _q_0.foo AS foo, _q_0.bar AS bar FROM (SELECT 1 AS foo, 2 AS bar FULL UNION ALL BY NAME ON (foo, bar) SELECT 3 AS bar, 4 AS baz) AS _q_0;
371+
372+
# dialect: bigquery
373+
# execute: false
374+
SELECT * FROM ((SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) LEFT UNION ALL BY NAME ON (bar) SELECT 3 AS foo, 4 AS bar);
375+
SELECT _q_0.bar AS bar FROM ((SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) LEFT UNION ALL BY NAME ON (bar) SELECT 3 AS foo, 4 AS bar) AS _q_0;
376+
377+
# dialect: bigquery
378+
# execute: false
379+
SELECT * FROM ((SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) FULL UNION ALL BY NAME ON (foo, qux) SELECT 3 AS qux, 4 AS bar);
380+
SELECT _q_0.foo AS foo, _q_0.qux AS qux FROM ((SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) FULL UNION ALL BY NAME ON (foo, qux) SELECT 3 AS qux, 4 AS bar) AS _q_0;
381+
382+
# dialect: bigquery
383+
# execute: false
384+
SELECT * FROM (((SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) FULL UNION ALL BY NAME ON (foo, qux) SELECT 3 AS qux, 4 AS bar) INNER UNION ALL BY NAME ON (foo) SELECT 6 AS foo);
385+
SELECT _q_0.foo AS foo FROM (((SELECT 1 AS foo, 2 AS bar LEFT UNION ALL BY NAME SELECT 3 AS bar, 4 AS baz) FULL UNION ALL BY NAME ON (foo, qux) SELECT 3 AS qux, 4 AS bar) INNER UNION ALL BY NAME ON (foo) SELECT 6 AS foo) AS _q_0;
386+
328387
--------------------------------------
329388
-- Subqueries
330389
--------------------------------------

0 commit comments

Comments
 (0)