15
15
16
16
import dataclasses
17
17
import functools
18
- import itertools
19
18
import typing
20
19
21
20
from google .cloud import bigquery
22
21
import pyarrow as pa
23
22
import sqlglot .expressions as sge
24
23
25
- from bigframes .core import expression , identifiers , nodes , rewrite
24
+ from bigframes .core import expression , guid , identifiers , nodes , rewrite
26
25
from bigframes .core .compile import configs
27
26
import bigframes .core .compile .sqlglot .scalar_compiler as scalar_compiler
28
27
import bigframes .core .compile .sqlglot .sqlglot_ir as ir
29
28
import bigframes .core .ordering as bf_ordering
30
29
31
30
32
- @dataclasses .dataclass (frozen = True )
33
31
class SQLGlotCompiler:
    """Compiles BigFrame nodes into SQL using SQLGlot."""

    uid_gen: guid.SequentialUIDGenerator
    """Generator for unique identifiers."""

    def __init__(self):
        # Each compiler instance owns its own sequential UID stream so that
        # repeated compilations are deterministic and independent of any
        # other compiler instance.
        self.uid_gen = guid.SequentialUIDGenerator()
36
40
def compile (
37
41
self ,
38
42
node : nodes .BigFrameNode ,
@@ -82,7 +86,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
82
86
result_node = typing .cast (
83
87
nodes .ResultNode , rewrite .column_pruning (result_node )
84
88
)
85
- result_node = _remap_variables (result_node )
89
+ result_node = self . _remap_variables (result_node )
86
90
sql = self ._compile_result_node (result_node )
87
91
return configs .CompileResult (
88
92
sql , result_node .schema .to_bigquery (), result_node .order_by
@@ -92,7 +96,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
92
96
result_node = dataclasses .replace (result_node , order_by = None )
93
97
result_node = typing .cast (nodes .ResultNode , rewrite .column_pruning (result_node ))
94
98
95
- result_node = _remap_variables (result_node )
99
+ result_node = self . _remap_variables (result_node )
96
100
sql = self ._compile_result_node (result_node )
97
101
# Return the ordering iff no extra columns are needed to define the row order
98
102
if ordering is not None :
@@ -106,63 +110,62 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
106
110
sql , result_node .schema .to_bigquery (), output_order
107
111
)
108
112
113
+ def _remap_variables (self , node : nodes .ResultNode ) -> nodes .ResultNode :
114
+ """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
115
+
116
+ result_node , _ = rewrite .remap_variables (
117
+ node , map (identifiers .ColumnId , self .uid_gen .get_uid_stream ("bfcol_" ))
118
+ )
119
+ return typing .cast (nodes .ResultNode , result_node )
120
+
109
121
def _compile_result_node (self , root : nodes .ResultNode ) -> str :
110
- sqlglot_ir = compile_node (root .child )
122
+ sqlglot_ir = self . compile_node (root .child )
111
123
# TODO: add order_by, limit, and selections to sqlglot_expr
112
124
return sqlglot_ir .sql
113
125
126
+ @functools .lru_cache (maxsize = 5000 )
127
+ def compile_node (self , node : nodes .BigFrameNode ) -> ir .SQLGlotIR :
128
+ """Compiles node into CompileArrayValue. Caches result."""
129
+ return node .reduce_up (
130
+ lambda node , children : self ._compile_node (node , * children )
131
+ )
114
132
115
- def _replace_unsupported_ops (node : nodes .BigFrameNode ):
116
- node = nodes .bottom_up (node , rewrite .rewrite_slice )
117
- node = nodes .bottom_up (node , rewrite .rewrite_timedelta_expressions )
118
- node = nodes .bottom_up (node , rewrite .rewrite_range_rolling )
119
- return node
120
-
121
-
122
- def _remap_variables (node : nodes .ResultNode ) -> nodes .ResultNode :
123
- """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
124
-
125
- def anonymous_column_ids () -> typing .Generator [identifiers .ColumnId , None , None ]:
126
- for i in itertools .count ():
127
- yield identifiers .ColumnId (name = f"bfcol_{ i } " )
128
-
129
- result_node , _ = rewrite .remap_variables (node , anonymous_column_ids ())
130
- return typing .cast (nodes .ResultNode , result_node )
131
-
132
-
133
- @functools .lru_cache (maxsize = 5000 )
134
- def compile_node (node : nodes .BigFrameNode ) -> ir .SQLGlotIR :
135
- """Compiles node into CompileArrayValue. Caches result."""
136
- return node .reduce_up (lambda node , children : _compile_node (node , * children ))
137
-
138
-
139
- @functools .singledispatch
140
- def _compile_node (
141
- node : nodes .BigFrameNode , * compiled_children : ir .SQLGlotIR
142
- ) -> ir .SQLGlotIR :
143
- """Defines transformation but isn't cached, always use compile_node instead"""
144
- raise ValueError (f"Can't compile unrecognized node: { node } " )
133
    @functools.singledispatchmethod
    def _compile_node(
        self, node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
    ) -> ir.SQLGlotIR:
        """Per-node-type compilation hook; concrete handlers are registered below.

        Defines the transformation but isn't cached — always call
        `compile_node` instead. This fallback fires only for node types with
        no registered handler.
        """
        raise ValueError(f"Can't compile unrecognized node: {node}")
140
+ @_compile_node .register
141
+ def compile_readlocal (self , node : nodes .ReadLocalNode , * args ) -> ir .SQLGlotIR :
142
+ pa_table = node .local_data_source .data
143
+ pa_table = pa_table .select ([item .source_id for item in node .scan_list .items ])
144
+ pa_table = pa_table .rename_columns (
145
+ [item .id .sql for item in node .scan_list .items ]
146
+ )
145
147
148
+ offsets = node .offsets_col .sql if node .offsets_col else None
149
+ if offsets :
150
+ pa_table = pa_table .append_column (
151
+ offsets , pa .array (range (pa_table .num_rows ), type = pa .int64 ())
152
+ )
146
153
147
- @_compile_node .register
148
- def compile_readlocal (node : nodes .ReadLocalNode , * args ) -> ir .SQLGlotIR :
149
- pa_table = node .local_data_source .data
150
- pa_table = pa_table .select ([item .source_id for item in node .scan_list .items ])
151
- pa_table = pa_table .rename_columns ([item .id .sql for item in node .scan_list .items ])
154
+ return ir .SQLGlotIR .from_pyarrow (pa_table , node .schema , uid_gen = self .uid_gen )
152
155
153
- offsets = node .offsets_col .sql if node .offsets_col else None
154
- if offsets :
155
- pa_table = pa_table .append_column (
156
- offsets , pa .array (range (pa_table .num_rows ), type = pa .int64 ())
156
+ @_compile_node .register
157
+ def compile_selection (
158
+ self , node : nodes .SelectionNode , child : ir .SQLGlotIR
159
+ ) -> ir .SQLGlotIR :
160
+ selected_cols : tuple [tuple [str , sge .Expression ], ...] = tuple (
161
+ (id .sql , scalar_compiler .compile_scalar_expression (expr ))
162
+ for expr , id in node .input_output_pairs
157
163
)
164
+ return child .select (selected_cols )
158
165
159
- return ir .SQLGlotIR .from_pyarrow (pa_table , node .schema )
160
166
161
-
162
- @_compile_node .register
163
- def compile_selection (node : nodes .SelectionNode , child : ir .SQLGlotIR ) -> ir .SQLGlotIR :
164
- select_cols : typing .Dict [str , sge .Expression ] = {
165
- id .name : scalar_compiler .compile_scalar_expression (expr )
166
- for expr , id in node .input_output_pairs
167
- }
168
- return child .select (select_cols )
167
def _replace_unsupported_ops(node: nodes.BigFrameNode):
    """Applies bottom-up rewrites for node types this backend can't compile directly."""
    rewriters = (
        rewrite.rewrite_slice,
        rewrite.rewrite_timedelta_expressions,
        rewrite.rewrite_range_rolling,
    )
    for rewriter in rewriters:
        node = nodes.bottom_up(node, rewriter)
    return node
0 commit comments