Skip to content

Commit 86b7504

Browse files
authored
refactor: add uid generator and encapsulate query as cte in SQLGlotCompiler (#1679)
1 parent acd09b1 commit 86b7504

File tree

10 files changed

+352
-276
lines changed

10 files changed

+352
-276
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ repos:
2020
hooks:
2121
- id: trailing-whitespace
2222
- id: end-of-file-fixer
23+
exclude: "^tests/unit/core/compile/sqlglot/snapshots"
2324
- id: check-yaml
2425
- repo: https://github.com/pycqa/isort
2526
rev: 5.12.0

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 57 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,28 @@
1515

1616
import dataclasses
1717
import functools
18-
import itertools
1918
import typing
2019

2120
from google.cloud import bigquery
2221
import pyarrow as pa
2322
import sqlglot.expressions as sge
2423

25-
from bigframes.core import expression, identifiers, nodes, rewrite
24+
from bigframes.core import expression, guid, identifiers, nodes, rewrite
2625
from bigframes.core.compile import configs
2726
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2827
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2928
import bigframes.core.ordering as bf_ordering
3029

3130

32-
@dataclasses.dataclass(frozen=True)
3331
class SQLGlotCompiler:
3432
"""Compiles BigFrame nodes into SQL using SQLGlot."""
3533

34+
uid_gen: guid.SequentialUIDGenerator
35+
"""Generator for unique identifiers."""
36+
37+
def __init__(self):
38+
self.uid_gen = guid.SequentialUIDGenerator()
39+
3640
def compile(
3741
self,
3842
node: nodes.BigFrameNode,
@@ -82,7 +86,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
8286
result_node = typing.cast(
8387
nodes.ResultNode, rewrite.column_pruning(result_node)
8488
)
85-
result_node = _remap_variables(result_node)
89+
result_node = self._remap_variables(result_node)
8690
sql = self._compile_result_node(result_node)
8791
return configs.CompileResult(
8892
sql, result_node.schema.to_bigquery(), result_node.order_by
@@ -92,7 +96,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
9296
result_node = dataclasses.replace(result_node, order_by=None)
9397
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
9498

95-
result_node = _remap_variables(result_node)
99+
result_node = self._remap_variables(result_node)
96100
sql = self._compile_result_node(result_node)
97101
# Return the ordering iff no extra columns are needed to define the row order
98102
if ordering is not None:
@@ -106,63 +110,62 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
106110
sql, result_node.schema.to_bigquery(), output_order
107111
)
108112

113+
def _remap_variables(self, node: nodes.ResultNode) -> nodes.ResultNode:
114+
"""Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
115+
116+
result_node, _ = rewrite.remap_variables(
117+
node, map(identifiers.ColumnId, self.uid_gen.get_uid_stream("bfcol_"))
118+
)
119+
return typing.cast(nodes.ResultNode, result_node)
120+
109121
def _compile_result_node(self, root: nodes.ResultNode) -> str:
110-
sqlglot_ir = compile_node(root.child)
122+
sqlglot_ir = self.compile_node(root.child)
111123
# TODO: add order_by, limit, and selections to sqlglot_expr
112124
return sqlglot_ir.sql
113125

126+
@functools.lru_cache(maxsize=5000)
127+
def compile_node(self, node: nodes.BigFrameNode) -> ir.SQLGlotIR:
128+
"""Compiles node into CompileArrayValue. Caches result."""
129+
return node.reduce_up(
130+
lambda node, children: self._compile_node(node, *children)
131+
)
114132

115-
def _replace_unsupported_ops(node: nodes.BigFrameNode):
116-
node = nodes.bottom_up(node, rewrite.rewrite_slice)
117-
node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
118-
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
119-
return node
120-
121-
122-
def _remap_variables(node: nodes.ResultNode) -> nodes.ResultNode:
123-
"""Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
124-
125-
def anonymous_column_ids() -> typing.Generator[identifiers.ColumnId, None, None]:
126-
for i in itertools.count():
127-
yield identifiers.ColumnId(name=f"bfcol_{i}")
128-
129-
result_node, _ = rewrite.remap_variables(node, anonymous_column_ids())
130-
return typing.cast(nodes.ResultNode, result_node)
131-
132-
133-
@functools.lru_cache(maxsize=5000)
134-
def compile_node(node: nodes.BigFrameNode) -> ir.SQLGlotIR:
135-
"""Compiles node into CompileArrayValue. Caches result."""
136-
return node.reduce_up(lambda node, children: _compile_node(node, *children))
137-
138-
139-
@functools.singledispatch
140-
def _compile_node(
141-
node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
142-
) -> ir.SQLGlotIR:
143-
"""Defines transformation but isn't cached, always use compile_node instead"""
144-
raise ValueError(f"Can't compile unrecognized node: {node}")
133+
@functools.singledispatchmethod
134+
def _compile_node(
135+
self, node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
136+
) -> ir.SQLGlotIR:
137+
"""Defines transformation but isn't cached, always use compile_node instead"""
138+
raise ValueError(f"Can't compile unrecognized node: {node}")
139+
140+
@_compile_node.register
141+
def compile_readlocal(self, node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
142+
pa_table = node.local_data_source.data
143+
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
144+
pa_table = pa_table.rename_columns(
145+
[item.id.sql for item in node.scan_list.items]
146+
)
145147

148+
offsets = node.offsets_col.sql if node.offsets_col else None
149+
if offsets:
150+
pa_table = pa_table.append_column(
151+
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
152+
)
146153

147-
@_compile_node.register
148-
def compile_readlocal(node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
149-
pa_table = node.local_data_source.data
150-
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
151-
pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
154+
return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=self.uid_gen)
152155

153-
offsets = node.offsets_col.sql if node.offsets_col else None
154-
if offsets:
155-
pa_table = pa_table.append_column(
156-
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
156+
@_compile_node.register
157+
def compile_selection(
158+
self, node: nodes.SelectionNode, child: ir.SQLGlotIR
159+
) -> ir.SQLGlotIR:
160+
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
161+
(id.sql, scalar_compiler.compile_scalar_expression(expr))
162+
for expr, id in node.input_output_pairs
157163
)
164+
return child.select(selected_cols)
158165

159-
return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema)
160166

161-
162-
@_compile_node.register
163-
def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
164-
select_cols: typing.Dict[str, sge.Expression] = {
165-
id.name: scalar_compiler.compile_scalar_expression(expr)
166-
for expr, id in node.input_output_pairs
167-
}
168-
return child.select(select_cols)
167+
def _replace_unsupported_ops(node: nodes.BigFrameNode):
168+
node = nodes.bottom_up(node, rewrite.rewrite_slice)
169+
node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
170+
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
171+
return node

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import sqlglot.expressions as sge
2424

2525
from bigframes import dtypes
26+
from bigframes.core import guid
2627
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
2728
import bigframes.core.local_data as local_data
2829
import bigframes.core.schema as schemata
@@ -52,14 +53,20 @@ class SQLGlotIR:
5253
pretty: bool = True
5354
"""Whether to pretty-print the generated SQL."""
5455

56+
uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
57+
"""Generator for unique identifiers."""
58+
5559
@property
5660
def sql(self) -> str:
5761
"""Generate SQL string from the given expression."""
5862
return self.expr.sql(dialect=self.dialect, pretty=self.pretty)
5963

6064
@classmethod
6165
def from_pyarrow(
62-
cls, pa_table: pa.Table, schema: schemata.ArraySchema
66+
cls,
67+
pa_table: pa.Table,
68+
schema: schemata.ArraySchema,
69+
uid_gen: guid.SequentialUIDGenerator,
6370
) -> SQLGlotIR:
6471
"""Builds SQLGlot expression from pyarrow table."""
6572
dtype_expr = sge.DataType(
@@ -95,21 +102,44 @@ def from_pyarrow(
95102
),
96103
],
97104
)
98-
return cls(expr=sg.select(sge.Star()).from_(expr))
105+
return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen)
99106

100107
def select(
101108
self,
102-
select_cols: typing.Dict[str, sge.Expression],
109+
selected_cols: tuple[tuple[str, sge.Expression], ...],
103110
) -> SQLGlotIR:
104-
selected_cols = [
111+
cols_expr = [
105112
sge.Alias(
106113
this=expr,
107114
alias=sge.to_identifier(id, quoted=self.quoted),
108115
)
109-
for id, expr in select_cols.items()
116+
for id, expr in selected_cols
110117
]
111-
expr = self.expr.select(*selected_cols, append=False)
112-
return SQLGlotIR(expr=expr)
118+
new_expr = self._encapsulate_as_cte().select(*cols_expr, append=False)
119+
return SQLGlotIR(expr=new_expr)
120+
121+
def _encapsulate_as_cte(
122+
self,
123+
) -> sge.Select:
124+
"""Transforms a given sge.Select query by pushing its main SELECT statement
125+
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
126+
for the new query."""
127+
select_expr = self.expr.copy()
128+
129+
existing_ctes = select_expr.args.pop("with", [])
130+
new_cte_name = sge.to_identifier(
131+
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
132+
)
133+
new_cte = sge.CTE(
134+
this=select_expr,
135+
alias=new_cte_name,
136+
)
137+
new_with_clause = sge.With(expressions=existing_ctes + [new_cte])
138+
new_select_expr = (
139+
sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
140+
)
141+
new_select_expr.set("with", new_with_clause)
142+
return new_select_expr
113143

114144

115145
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:

bigframes/core/guid.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,32 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import typing
16+
1517
_GUID_COUNTER = 0
1618

1719

1820
def generate_guid(prefix="col_"):
1921
global _GUID_COUNTER
2022
_GUID_COUNTER += 1
2123
return f"bfuid_{prefix}{_GUID_COUNTER}"
24+
25+
26+
class SequentialUIDGenerator:
27+
"""Produces a sequence of UIDs, such as {"t0", "t1", "c0", "t2", ...}, by
28+
cycling through provided prefixes (e.g., "t" and "c").
29+
Note: this function is not thread-safe.
30+
"""
31+
32+
def __init__(self):
33+
self.prefix_counters: typing.Dict[str, int] = {}
34+
35+
def get_uid_stream(self, prefix: str) -> typing.Generator[str, None, None]:
36+
"""Yields a continuous stream of raw UID strings for the given prefix."""
37+
if prefix not in self.prefix_counters:
38+
self.prefix_counters[prefix] = 0
39+
40+
while True:
41+
uid = f"{prefix}{self.prefix_counters[prefix]}"
42+
self.prefix_counters[prefix] += 1
43+
yield uid

bigframes/core/rewrite/identifiers.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,20 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from typing import Generator, Tuple
16+
import typing
1717

18-
import bigframes.core.identifiers
19-
import bigframes.core.nodes
18+
from bigframes.core import identifiers, nodes
2019

2120

2221
# TODO: May as well just outright remove selection nodes in this process.
2322
def remap_variables(
24-
root: bigframes.core.nodes.BigFrameNode,
25-
id_generator: Generator[bigframes.core.identifiers.ColumnId, None, None],
26-
) -> Tuple[
27-
bigframes.core.nodes.BigFrameNode,
28-
dict[bigframes.core.identifiers.ColumnId, bigframes.core.identifiers.ColumnId],
23+
root: nodes.BigFrameNode,
24+
id_generator: typing.Iterator[identifiers.ColumnId],
25+
) -> typing.Tuple[
26+
nodes.BigFrameNode,
27+
dict[identifiers.ColumnId, identifiers.ColumnId],
2928
]:
30-
"""
31-
Remap all variables in the BFET using the id_generator.
29+
"""Remaps `ColumnId`s in the BFET to produce deterministic and sequential UIDs.
3230
3331
Note: this will convert a DAG to a tree.
3432
"""

tests/unit/core/compile/sqlglot/compiler_session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
class SQLCompilerExecutor(bigframes.session.executor.Executor):
2828
"""Executor for SQL compilation using sqlglot."""
2929

30-
compiler = sqlglot.SQLGlotCompiler()
30+
compiler = sqlglot
3131

3232
def to_sql(
3333
self,
@@ -41,7 +41,9 @@ def to_sql(
4141

4242
# Compared with BigQueryCachingExecutor, SQLCompilerExecutor skips
4343
# caching the subtree.
44-
return self.compiler.compile(array_value.node, ordered=ordered)
44+
return self.compiler.SQLGlotCompiler().compile(
45+
array_value.node, ordered=ordered
46+
)
4547

4648

4749
class SQLCompilerSession(bigframes.session.Session):

0 commit comments

Comments
 (0)