Skip to content

Commit 865a77d

Browse files
committed
refactor: add table aliases to prevent ambiguity in SQLGlotIR
1 parent 36261f5 commit 865a77d

File tree

285 files changed

+414
-328
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

285 files changed

+414
-328
lines changed

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 113 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -134,24 +134,50 @@ def from_table(
134134
"""
135135
version = (
136136
sge.Version(
137-
this="TIMESTAMP",
138-
expression=sge.Literal(this=system_time.isoformat(), is_string=True),
137+
this=sge.Identifier(this="SYSTEM_TIME", quoted=False),
138+
expression=sge.Literal.string(system_time.isoformat()),
139139
kind="AS OF",
140140
)
141141
if system_time
142142
else None
143143
)
144+
table_alias = next(uid_gen.get_uid_stream("bft_"))
144145
table_expr = sge.Table(
145146
this=sg.to_identifier(table_id, quoted=cls.quoted),
146147
db=sg.to_identifier(dataset_id, quoted=cls.quoted),
147148
catalog=sg.to_identifier(project_id, quoted=cls.quoted),
148149
version=version,
150+
alias=sge.Identifier(this=table_alias, quoted=cls.quoted),
149151
)
150152
if sql_predicate:
151-
select_expr = sge.Select().select(sge.Star()).from_(table_expr)
152-
select_expr = select_expr.where(
153-
sg.parse_one(sql_predicate, dialect=cls.dialect), append=False
153+
table_alias = sge.to_identifier(
154+
next(uid_gen.get_uid_stream("bft_")), quoted=cls.quoted
154155
)
156+
# WORKAROUND: SQLGlot renders Table + version + alias in the wrong order for BigQuery.
157+
# Wrapping in a subquery ensures valid SQL: (SELECT * FROM table FOR SYSTEM_TIME AS OF ...) AS alias
158+
if version:
159+
from_item = (
160+
sge.Select()
161+
.select(sge.Star())
162+
.from_(table_expr)
163+
.subquery(alias=table_alias)
164+
)
165+
else:
166+
from_item = sge.Alias(this=table_expr, alias=table_alias)
167+
168+
select_expr = (
169+
sge.Select()
170+
.select(sge.Column(this=sge.Star(), table=table_alias))
171+
.from_(from_item)
172+
)
173+
174+
predicate_expr = sg.parse_one(sql_predicate, dialect=cls.dialect)
175+
predicate_expr = predicate_expr.transform(
176+
lambda e: sge.Column(this=e.this, table=table_alias)
177+
if isinstance(e, sge.Column) and not e.table
178+
else e
179+
)
180+
select_expr = select_expr.where(predicate_expr, append=False)
155181
return cls(expr=select_expr, uid_gen=uid_gen)
156182

157183
return cls(expr=table_expr, uid_gen=uid_gen)
@@ -165,35 +191,66 @@ def select(
165191
) -> SQLGlotIR:
166192
# TODO: Explicitly insert CTEs into plan
167193
if isinstance(self.expr, sge.Select):
168-
new_expr, _ = self._select_to_cte()
194+
new_expr, table_alias = self._select_to_cte()
169195
else:
170-
new_expr = sge.Select().from_(self.expr)
196+
table_alias = sge.to_identifier(
197+
next(self.uid_gen.get_uid_stream("bft_")), quoted=self.quoted
198+
)
199+
# WORKAROUND: SQLGlot renders Table + version + alias in the wrong order for BigQuery.
200+
# Wrapping in a subquery ensures valid SQL: (SELECT * FROM table FOR SYSTEM_TIME AS OF ...) AS alias
201+
if isinstance(self.expr, sge.Table) and self.expr.args.get("version"):
202+
from_item = (
203+
sge.Select()
204+
.select(sge.Star())
205+
.from_(self.expr)
206+
.subquery(alias=table_alias)
207+
)
208+
else:
209+
from_item = sge.Alias(this=self.expr, alias=table_alias)
210+
211+
new_expr = sge.Select().from_(from_item)
171212

172213
if len(sorting) > 0:
173-
new_expr = new_expr.order_by(*sorting)
214+
new_expr = new_expr.order_by(
215+
*[self._qualify(sort, table_alias) for sort in sorting]
216+
)
174217

175218
if len(selections) > 0:
176219
to_select = [
177220
sge.Alias(
178-
this=expr,
221+
this=self._qualify(expr, table_alias),
179222
alias=sge.to_identifier(id, quoted=self.quoted),
180223
)
181224
if expr.alias_or_name != id
182-
else expr
225+
else self._qualify(expr, table_alias)
183226
for id, expr in selections
184227
]
185228
new_expr = new_expr.select(*to_select, append=False)
186229
else:
187-
new_expr = new_expr.select(sge.Star(), append=False)
230+
new_expr = new_expr.select(
231+
sge.Column(this=sge.Star(), table=table_alias), append=False
232+
)
188233

189234
if len(predicates) > 0:
190-
condition = _and(predicates)
235+
condition = _and(
236+
tuple(self._qualify(predicate, table_alias) for predicate in predicates)
237+
)
191238
new_expr = new_expr.where(condition, append=False)
192239
if limit is not None:
193240
new_expr = new_expr.limit(limit)
194241

195242
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
196243

244+
def _qualify(
245+
self, expr: sge.Expression, table_alias: sge.Identifier
246+
) -> sge.Expression:
247+
def _transform(e):
248+
if isinstance(e, sge.Column) and not e.table:
249+
return sge.Column(this=e.this, table=table_alias)
250+
return e
251+
252+
return expr.transform(_transform)
253+
197254
@classmethod
198255
def from_query_string(
199256
cls,
@@ -210,7 +267,11 @@ def from_query_string(
210267
this=query_string,
211268
alias=cte_name,
212269
)
213-
select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
270+
select_expr = (
271+
sge.Select()
272+
.select(sge.Column(this=sge.Star(), table=cte_name))
273+
.from_(sge.Table(this=cte_name))
274+
)
214275
select_expr = _set_query_ctes(select_expr, [cte])
215276
return cls(expr=select_expr, uid_gen=uid_gen)
216277

@@ -276,9 +337,23 @@ def join(
276337
right_select, right_ctes = _pop_query_ctes(right_select)
277338
merged_ctes = _merge_ctes(left_ctes, right_ctes)
278339

340+
# Qualify join conditions
341+
qualified_conditions = tuple(
342+
(
343+
typed_expr.TypedExpr(
344+
self._qualify(left.expr, left_cte_name), left.dtype
345+
),
346+
typed_expr.TypedExpr(
347+
self._qualify(right.expr, right_cte_name), right.dtype
348+
),
349+
)
350+
for left, right in conditions
351+
)
352+
279353
join_on = _and(
280354
tuple(
281-
_join_condition(left, right, joins_nulls) for left, right in conditions
355+
_join_condition(left, right, joins_nulls)
356+
for left, right in qualified_conditions
282357
)
283358
)
284359

@@ -310,7 +385,7 @@ def isin_join(
310385
merged_ctes = _merge_ctes(left_ctes, right_ctes)
311386

312387
left_condition = typed_expr.TypedExpr(
313-
sge.Column(this=conditions[0].expr, table=left_cte_name),
388+
self._qualify(conditions[0].expr, left_cte_name),
314389
conditions[0].dtype,
315390
)
316391

@@ -320,7 +395,7 @@ def isin_join(
320395
next(self.uid_gen.get_uid_stream("bft_")), quoted=self.quoted
321396
)
322397
right_condition = typed_expr.TypedExpr(
323-
sge.Column(this=conditions[1].expr, table=right_table_name),
398+
self._qualify(conditions[1].expr, right_table_name),
324399
conditions[1].dtype,
325400
)
326401
new_column = sge.Exists(
@@ -371,7 +446,8 @@ def sample(self, fraction: float) -> SQLGlotIR:
371446
expression=_literal(fraction, dtypes.FLOAT_DTYPE),
372447
)
373448

374-
new_expr = self._select_to_cte()[0].where(condition, append=False)
449+
new_expr, table_alias = self._select_to_cte()
450+
new_expr = new_expr.where(condition, append=False)
375451
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
376452

377453
def aggregate(
@@ -387,23 +463,27 @@ def aggregate(
387463
by_cols: column expressions for aggregation
388464
dropna_cols: columns whose null keys should be dropped
389465
"""
390-
aggregations_expr = [
466+
new_expr, table_alias = self._select_to_cte()
467+
468+
qualified_aggregations_expr = [
391469
sge.Alias(
392-
this=expr,
470+
this=self._qualify(expr, table_alias),
393471
alias=sge.to_identifier(id, quoted=self.quoted),
394472
)
395473
for id, expr in aggregations
396474
]
397475

398-
new_expr, _ = self._select_to_cte()
399-
new_expr = new_expr.group_by(*by_cols).select(
400-
*[*by_cols, *aggregations_expr], append=False
476+
qualified_by_cols = [self._qualify(col, table_alias) for col in by_cols]
477+
qualified_dropna_cols = [self._qualify(col, table_alias) for col in dropna_cols]
478+
479+
new_expr = new_expr.group_by(*qualified_by_cols).select(
480+
*[*qualified_by_cols, *qualified_aggregations_expr], append=False
401481
)
402482

403483
condition = _and(
404484
tuple(
405485
sg.not_(sge.Is(this=drop_col, expression=sge.Null()))
406-
for drop_col in dropna_cols
486+
for drop_col in qualified_dropna_cols
407487
)
408488
)
409489
if condition is not None:
@@ -496,14 +576,16 @@ def _explode_single_column(
496576
unnested_column_alias = sge.to_identifier(
497577
next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted
498578
)
579+
580+
new_expr, table_alias = self._select_to_cte()
581+
499582
unnest_expr = sge.Unnest(
500-
expressions=[column],
583+
expressions=[sge.Column(this=column, table=table_alias)],
501584
alias=sge.TableAlias(columns=[unnested_column_alias]),
502585
offset=offset,
503586
)
504587
selection = sge.Star(replace=[unnested_column_alias.as_(column)])
505588

506-
new_expr, _ = self._select_to_cte()
507589
# Use LEFT JOIN to preserve rows when unnesting empty arrays.
508590
new_expr = new_expr.select(selection, append=False).join(
509591
unnest_expr, join_type="LEFT"
@@ -524,10 +606,12 @@ def _explode_multiple_columns(
524606
for column_name in column_names
525607
]
526608

609+
new_expr, table_alias = self._select_to_cte()
610+
527611
# If there are multiple columns, we need to unnest by zipping the arrays:
528612
# https://cloud.google.com/bigquery/docs/arrays#zipping_arrays
529613
column_lengths = [
530-
sge.func("ARRAY_LENGTH", sge.to_identifier(column, quoted=self.quoted)) - 1
614+
sge.func("ARRAY_LENGTH", sge.Column(this=column, table=table_alias)) - 1
531615
for column in columns
532616
]
533617
generate_array = sge.func(
@@ -554,7 +638,6 @@ def _explode_multiple_columns(
554638
for column in columns
555639
]
556640
)
557-
new_expr, _ = self._select_to_cte()
558641
# Use LEFT JOIN to preserve rows when unnesting empty arrays.
559642
new_expr = new_expr.select(selection, append=False).join(
560643
unnest_expr, join_type="LEFT"
@@ -590,7 +673,9 @@ def _select_to_cte(self) -> tuple[sge.Select, sge.Identifier]:
590673
alias=cte_name,
591674
)
592675
new_select_expr = (
593-
sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
676+
sge.Select()
677+
.select(sge.Column(this=sge.Star(), table=cte_name))
678+
.from_(sge.Table(this=cte_name))
594679
)
595680
new_select_expr = _set_query_ctes(new_select_expr, [*existing_ctes, new_cte])
596681
return new_select_expr, cte_name

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_corr/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ WITH `bfcte_0` AS (
22
SELECT
33
`int64_col`,
44
`float64_col`
5-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
66
), `bfcte_1` AS (
77
SELECT
88
CORR(`int64_col`, `float64_col`) AS `bfcol_2`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_cov/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ WITH `bfcte_0` AS (
22
SELECT
33
`int64_col`,
44
`float64_col`
5-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
66
), `bfcte_1` AS (
77
SELECT
88
COVAR_SAMP(`int64_col`, `float64_col`) AS `bfcol_2`
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
SELECT
22
ROW_NUMBER() OVER () - 1 AS `row_number`
3-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
SELECT
22
ROW_NUMBER() OVER (ORDER BY `int64_col` ASC NULLS LAST) - 1 AS `row_number`
3-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
WITH `bfcte_0` AS (
22
SELECT
33
*
4-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
55
), `bfcte_1` AS (
66
SELECT
77
COUNT(1) AS `bfcol_32`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_ordered_unary_compiler/test_array_agg/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
WITH `bfcte_0` AS (
22
SELECT
33
`int64_col`
4-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
55
), `bfcte_1` AS (
66
SELECT
77
ARRAY_AGG(`int64_col` IGNORE NULLS ORDER BY `int64_col` IS NULL ASC, `int64_col` ASC) AS `bfcol_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_ordered_unary_compiler/test_string_agg/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
WITH `bfcte_0` AS (
22
SELECT
33
`string_col`
4-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
55
), `bfcte_1` AS (
66
SELECT
77
COALESCE(

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ WITH `bfcte_0` AS (
22
SELECT
33
`bool_col`,
44
`int64_col`
5-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
66
), `bfcte_1` AS (
77
SELECT
88
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_2`,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
SELECT
22
COALESCE(LOGICAL_AND(`bool_col`) OVER (), TRUE) AS `agg_bool`
3-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

0 commit comments

Comments
 (0)