Skip to content

Commit e05cfce

Browse files
committed
refactor: add uid generator and encapsulate query as cte
1 parent f3fd7e2 commit e05cfce

File tree

10 files changed

+343
-280
lines changed

10 files changed

+343
-280
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ repos:
2020
hooks:
2121
- id: trailing-whitespace
2222
- id: end-of-file-fixer
23+
exclude: "^tests/unit/core/compile/sqlglot/snapshots"
2324
- id: check-yaml
2425
- repo: https://github.com/pycqa/isort
2526
rev: 5.12.0

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 48 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515

1616
import dataclasses
1717
import functools
18-
import itertools
1918
import typing
2019

2120
from google.cloud import bigquery
2221
import pyarrow as pa
2322
import sqlglot.expressions as sge
2423

25-
from bigframes.core import expression, identifiers, nodes, rewrite
24+
from bigframes.core import expression, guid, nodes, rewrite
2625
from bigframes.core.compile import configs
2726
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2827
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
@@ -33,6 +32,9 @@
3332
class SQLGlotCompiler:
3433
"""Compiles BigFrame nodes into SQL using SQLGlot."""
3534

35+
uid_gen: guid.SequentialUIDGenerator
36+
"""Generator for unique identifiers."""
37+
3638
def compile(
3739
self,
3840
node: nodes.BigFrameNode,
@@ -82,8 +84,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
8284
result_node = typing.cast(
8385
nodes.ResultNode, rewrite.column_pruning(result_node)
8486
)
85-
result_node = _remap_variables(result_node)
86-
sql = self._compile_result_node(result_node)
87+
remap_node, _ = rewrite.remap_variables(result_node, self.uid_gen)
88+
sql = self._compile_result_node(typing.cast(nodes.ResultNode, remap_node))
8789
return configs.CompileResult(
8890
sql, result_node.schema.to_bigquery(), result_node.order_by
8991
)
@@ -92,8 +94,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
9294
result_node = dataclasses.replace(result_node, order_by=None)
9395
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
9496

95-
result_node = _remap_variables(result_node)
96-
sql = self._compile_result_node(result_node)
97+
remap_node, _ = rewrite.remap_variables(result_node, self.uid_gen)
98+
sql = self._compile_result_node(typing.cast(nodes.ResultNode, remap_node))
9799
# Return the ordering iff no extra columns are needed to define the row order
98100
if ordering is not None:
99101
output_order = (
@@ -107,62 +109,53 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
107109
)
108110

109111
def _compile_result_node(self, root: nodes.ResultNode) -> str:
110-
sqlglot_ir = compile_node(root.child)
112+
sqlglot_ir = self.compile_node(root.child)
111113
# TODO: add order_by, limit, and selections to sqlglot_expr
112114
return sqlglot_ir.sql
113115

116+
@functools.lru_cache(maxsize=5000)
117+
def compile_node(self, node: nodes.BigFrameNode) -> ir.SQLGlotIR:
118+
"""Compiles node into CompileArrayValue. Caches result."""
119+
return node.reduce_up(
120+
lambda node, children: self._compile_node(node, *children)
121+
)
114122

115-
def _replace_unsupported_ops(node: nodes.BigFrameNode):
116-
node = nodes.bottom_up(node, rewrite.rewrite_slice)
117-
node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
118-
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
119-
return node
120-
121-
122-
def _remap_variables(node: nodes.ResultNode) -> nodes.ResultNode:
123-
"""Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
124-
125-
def anonymous_column_ids() -> typing.Generator[identifiers.ColumnId, None, None]:
126-
for i in itertools.count():
127-
yield identifiers.ColumnId(name=f"bfcol_{i}")
128-
129-
result_node, _ = rewrite.remap_variables(node, anonymous_column_ids())
130-
return typing.cast(nodes.ResultNode, result_node)
131-
132-
133-
@functools.lru_cache(maxsize=5000)
134-
def compile_node(node: nodes.BigFrameNode) -> ir.SQLGlotIR:
135-
"""Compiles node into CompileArrayValue. Caches result."""
136-
return node.reduce_up(lambda node, children: _compile_node(node, *children))
137-
138-
139-
@functools.singledispatch
140-
def _compile_node(
141-
node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
142-
) -> ir.SQLGlotIR:
143-
"""Defines transformation but isn't cached, always use compile_node instead"""
144-
raise ValueError(f"Can't compile unrecognized node: {node}")
123+
@functools.singledispatchmethod
124+
def _compile_node(
125+
self, node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
126+
) -> ir.SQLGlotIR:
127+
"""Defines transformation but isn't cached, always use compile_node instead"""
128+
raise ValueError(f"Can't compile unrecognized node: {node}")
129+
130+
@_compile_node.register
131+
def compile_readlocal(self, node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
132+
pa_table = node.local_data_source.data
133+
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
134+
pa_table = pa_table.rename_columns(
135+
[item.id.sql for item in node.scan_list.items]
136+
)
145137

138+
offsets = node.offsets_col.sql if node.offsets_col else None
139+
if offsets:
140+
pa_table = pa_table.append_column(
141+
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
142+
)
146143

147-
@_compile_node.register
148-
def compile_readlocal(node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
149-
pa_table = node.local_data_source.data
150-
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
151-
pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
144+
return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=self.uid_gen)
152145

153-
offsets = node.offsets_col.sql if node.offsets_col else None
154-
if offsets:
155-
pa_table = pa_table.append_column(
156-
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
146+
@_compile_node.register
147+
def compile_selection(
148+
self, node: nodes.SelectionNode, child: ir.SQLGlotIR
149+
) -> ir.SQLGlotIR:
150+
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
151+
(id.sql, scalar_compiler.compile_scalar_expression(expr))
152+
for expr, id in node.input_output_pairs
157153
)
158-
159-
return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema)
154+
return child.select(selected_cols)
160155

161156

162-
@_compile_node.register
163-
def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
164-
select_cols: typing.Dict[str, sge.Expression] = {
165-
id.name: scalar_compiler.compile_scalar_expression(expr)
166-
for expr, id in node.input_output_pairs
167-
}
168-
return child.select(select_cols)
157+
def _replace_unsupported_ops(node: nodes.BigFrameNode):
158+
node = nodes.bottom_up(node, rewrite.rewrite_slice)
159+
node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
160+
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
161+
return node

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import sqlglot.expressions as sge
2424

2525
from bigframes import dtypes
26+
from bigframes.core import guid
2627
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
2728
import bigframes.core.local_data as local_data
2829
import bigframes.core.schema as schemata
@@ -52,14 +53,20 @@ class SQLGlotIR:
5253
pretty: bool = True
5354
"""Whether to pretty-print the generated SQL."""
5455

56+
uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
57+
"""Generator for unique identifiers."""
58+
5559
@property
5660
def sql(self) -> str:
5761
"""Generate SQL string from the given expression."""
5862
return self.expr.sql(dialect=self.dialect, pretty=self.pretty)
5963

6064
@classmethod
6165
def from_pyarrow(
62-
cls, pa_table: pa.Table, schema: schemata.ArraySchema
66+
cls,
67+
pa_table: pa.Table,
68+
schema: schemata.ArraySchema,
69+
uid_gen: guid.SequentialUIDGenerator,
6370
) -> SQLGlotIR:
6471
"""Builds SQLGlot expression from pyarrow table."""
6572
dtype_expr = sge.DataType(
@@ -95,21 +102,44 @@ def from_pyarrow(
95102
),
96103
],
97104
)
98-
return cls(expr=sg.select(sge.Star()).from_(expr))
105+
return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen)
99106

100107
def select(
101108
self,
102-
select_cols: typing.Dict[str, sge.Expression],
109+
selected_cols: tuple[tuple[str, sge.Expression], ...],
103110
) -> SQLGlotIR:
104-
selected_cols = [
111+
cols_expr = [
105112
sge.Alias(
106113
this=expr,
107114
alias=sge.to_identifier(id, quoted=self.quoted),
108115
)
109-
for id, expr in select_cols.items()
116+
for id, expr in selected_cols
110117
]
111-
expr = self.expr.select(*selected_cols, append=False)
112-
return SQLGlotIR(expr=expr)
118+
new_expr = self._encapsulate_as_cte().select(*cols_expr, append=False)
119+
return SQLGlotIR(expr=new_expr)
120+
121+
def _encapsulate_as_cte(
122+
self,
123+
) -> sge.Select:
124+
"""Transforms a given sge.Select query by pushing its main SELECT statement
125+
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
126+
for the new query."""
127+
select_expr = self.expr.copy()
128+
129+
existing_ctes = select_expr.args.pop("with", [])
130+
new_cte_name = sge.to_identifier(
131+
self.uid_gen.generate_sequential_uid("bfcte_"), quoted=self.quoted
132+
)
133+
new_cte = sge.CTE(
134+
this=select_expr,
135+
alias=new_cte_name,
136+
)
137+
new_with_clause = sge.With(expressions=existing_ctes + [new_cte])
138+
new_select_expr = (
139+
sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
140+
)
141+
new_select_expr.set("with", new_with_clause)
142+
return new_select_expr
113143

114144

115145
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:

bigframes/core/guid.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,21 @@ def generate_guid(prefix="col_"):
1919
global _GUID_COUNTER
2020
_GUID_COUNTER += 1
2121
return f"bfuid_{prefix}{_GUID_COUNTER}"
22+
23+
24+
class SequentialUIDGenerator:
25+
"""
26+
Generates sequential-like UIDs with multiple prefixes, e.g., "t0", "t1", "c0", "t2", etc.
27+
"""
28+
29+
def __init__(self):
30+
self.prefix_counters = {}
31+
32+
def generate_sequential_uid(self, prefix: str) -> str:
33+
"""Generates a sequential UID with specified prefix."""
34+
if prefix not in self.prefix_counters:
35+
self.prefix_counters[prefix] = 0
36+
37+
uid = f"{prefix}{self.prefix_counters[prefix]}"
38+
self.prefix_counters[prefix] += 1
39+
return uid

bigframes/core/rewrite/identifiers.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,25 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from typing import Generator, Tuple
16+
from typing import Tuple
1717

18-
import bigframes.core.identifiers
19-
import bigframes.core.nodes
18+
from bigframes.core import guid, identifiers, nodes
2019

2120

2221
# TODO: May as well just outright remove selection nodes in this process.
2322
def remap_variables(
24-
root: bigframes.core.nodes.BigFrameNode,
25-
id_generator: Generator[bigframes.core.identifiers.ColumnId, None, None],
26-
) -> Tuple[
27-
bigframes.core.nodes.BigFrameNode,
28-
dict[bigframes.core.identifiers.ColumnId, bigframes.core.identifiers.ColumnId],
29-
]:
30-
"""
31-
Remap all variables in the BFET using the id_generator.
23+
root: nodes.BigFrameNode,
24+
uid_gen: guid.SequentialUIDGenerator,
25+
) -> Tuple[nodes.BigFrameNode, dict[identifiers.ColumnId, identifiers.ColumnId],]:
26+
"""Remaps `ColumnId`s in the BFET to produce deterministic and sequential UIDs.
3227
3328
Note: this will convert a DAG to a tree.
3429
"""
3530
child_replacement_map = dict()
3631
ref_mapping = dict()
3732
# Sequential ids are assigned bottom-up left-to-right
3833
for child in root.child_nodes:
39-
new_child, child_var_mapping = remap_variables(child, id_generator=id_generator)
34+
new_child, child_var_mapping = remap_variables(child, uid_gen=uid_gen)
4035
child_replacement_map[child] = new_child
4136
ref_mapping.update(child_var_mapping)
4237

@@ -47,7 +42,10 @@ def remap_variables(
4742

4843
with_new_refs = with_new_children.remap_refs(ref_mapping)
4944

50-
node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids}
45+
node_var_mapping = {
46+
old_id: identifiers.ColumnId(name=uid_gen.generate_sequential_uid("bfcol_"))
47+
for old_id in root.node_defined_ids
48+
}
5149
with_new_vars = with_new_refs.remap_vars(node_var_mapping)
5250
with_new_vars._validate()
5351

tests/unit/core/compile/sqlglot/compiler_session.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import bigframes.core
2020
import bigframes.core.compile.sqlglot as sqlglot
21+
import bigframes.core.guid
2122
import bigframes.dataframe
2223
import bigframes.session.executor
2324
import bigframes.session.metrics
@@ -27,7 +28,7 @@
2728
class SQLCompilerExecutor(bigframes.session.executor.Executor):
2829
"""Executor for SQL compilation using sqlglot."""
2930

30-
compiler = sqlglot.SQLGlotCompiler()
31+
compiler = sqlglot
3132

3233
def to_sql(
3334
self,
@@ -41,7 +42,9 @@ def to_sql(
4142

4243
# Compared with BigQueryCachingExecutor, SQLCompilerExecutor skips
4344
# caching the subtree.
44-
return self.compiler.compile(array_value.node, ordered=ordered)
45+
return self.compiler.SQLGlotCompiler(
46+
uid_gen=bigframes.core.guid.SequentialUIDGenerator()
47+
).compile(array_value.node, ordered=ordered)
4548

4649

4750
class SQLCompilerSession(bigframes.session.Session):

0 commit comments

Comments
 (0)