15
15
16
16
import dataclasses
17
17
import functools
18
- import itertools
19
18
import typing
20
19
21
20
from google .cloud import bigquery
22
21
import pyarrow as pa
23
22
import sqlglot .expressions as sge
24
23
25
- from bigframes .core import expression , identifiers , nodes , rewrite
24
+ from bigframes .core import expression , guid , nodes , rewrite
26
25
from bigframes .core .compile import configs
27
26
import bigframes .core .compile .sqlglot .scalar_compiler as scalar_compiler
28
27
import bigframes .core .compile .sqlglot .sqlglot_ir as ir
33
32
class SQLGlotCompiler :
34
33
"""Compiles BigFrame nodes into SQL using SQLGlot."""
35
34
35
+ uid_gen : guid .SequentialUIDGenerator
36
+ """Generator for unique identifiers."""
37
+
36
38
def compile (
37
39
self ,
38
40
node : nodes .BigFrameNode ,
@@ -82,8 +84,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
82
84
result_node = typing .cast (
83
85
nodes .ResultNode , rewrite .column_pruning (result_node )
84
86
)
85
- result_node = _remap_variables (result_node )
86
- sql = self ._compile_result_node (result_node )
87
+ remap_node , _ = rewrite . remap_variables (result_node , self . uid_gen )
88
+ sql = self ._compile_result_node (typing . cast ( nodes . ResultNode , remap_node ) )
87
89
return configs .CompileResult (
88
90
sql , result_node .schema .to_bigquery (), result_node .order_by
89
91
)
@@ -92,8 +94,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
92
94
result_node = dataclasses .replace (result_node , order_by = None )
93
95
result_node = typing .cast (nodes .ResultNode , rewrite .column_pruning (result_node ))
94
96
95
- result_node = _remap_variables (result_node )
96
- sql = self ._compile_result_node (result_node )
97
+ remap_node , _ = rewrite . remap_variables (result_node , self . uid_gen )
98
+ sql = self ._compile_result_node (typing . cast ( nodes . ResultNode , remap_node ) )
97
99
# Return the ordering iff no extra columns are needed to define the row order
98
100
if ordering is not None :
99
101
output_order = (
@@ -107,62 +109,53 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
107
109
)
108
110
109
111
def _compile_result_node (self , root : nodes .ResultNode ) -> str :
110
- sqlglot_ir = compile_node (root .child )
112
+ sqlglot_ir = self . compile_node (root .child )
111
113
# TODO: add order_by, limit, and selections to sqlglot_expr
112
114
return sqlglot_ir .sql
113
115
116
+ @functools .lru_cache (maxsize = 5000 )
117
+ def compile_node (self , node : nodes .BigFrameNode ) -> ir .SQLGlotIR :
118
+ """Compiles node into CompileArrayValue. Caches result."""
119
+ return node .reduce_up (
120
+ lambda node , children : self ._compile_node (node , * children )
121
+ )
114
122
115
- def _replace_unsupported_ops (node : nodes .BigFrameNode ):
116
- node = nodes .bottom_up (node , rewrite .rewrite_slice )
117
- node = nodes .bottom_up (node , rewrite .rewrite_timedelta_expressions )
118
- node = nodes .bottom_up (node , rewrite .rewrite_range_rolling )
119
- return node
120
-
121
-
122
- def _remap_variables (node : nodes .ResultNode ) -> nodes .ResultNode :
123
- """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
124
-
125
- def anonymous_column_ids () -> typing .Generator [identifiers .ColumnId , None , None ]:
126
- for i in itertools .count ():
127
- yield identifiers .ColumnId (name = f"bfcol_{ i } " )
128
-
129
- result_node , _ = rewrite .remap_variables (node , anonymous_column_ids ())
130
- return typing .cast (nodes .ResultNode , result_node )
131
-
132
-
133
- @functools .lru_cache (maxsize = 5000 )
134
- def compile_node (node : nodes .BigFrameNode ) -> ir .SQLGlotIR :
135
- """Compiles node into CompileArrayValue. Caches result."""
136
- return node .reduce_up (lambda node , children : _compile_node (node , * children ))
137
-
138
-
139
- @functools .singledispatch
140
- def _compile_node (
141
- node : nodes .BigFrameNode , * compiled_children : ir .SQLGlotIR
142
- ) -> ir .SQLGlotIR :
143
- """Defines transformation but isn't cached, always use compile_node instead"""
144
- raise ValueError (f"Can't compile unrecognized node: { node } " )
123
+ @functools .singledispatchmethod
124
+ def _compile_node (
125
+ self , node : nodes .BigFrameNode , * compiled_children : ir .SQLGlotIR
126
+ ) -> ir .SQLGlotIR :
127
+ """Defines transformation but isn't cached, always use compile_node instead"""
128
+ raise ValueError (f"Can't compile unrecognized node: { node } " )
129
+
130
+ @_compile_node .register
131
+ def compile_readlocal (self , node : nodes .ReadLocalNode , * args ) -> ir .SQLGlotIR :
132
+ pa_table = node .local_data_source .data
133
+ pa_table = pa_table .select ([item .source_id for item in node .scan_list .items ])
134
+ pa_table = pa_table .rename_columns (
135
+ [item .id .sql for item in node .scan_list .items ]
136
+ )
145
137
138
+ offsets = node .offsets_col .sql if node .offsets_col else None
139
+ if offsets :
140
+ pa_table = pa_table .append_column (
141
+ offsets , pa .array (range (pa_table .num_rows ), type = pa .int64 ())
142
+ )
146
143
147
- @_compile_node .register
148
- def compile_readlocal (node : nodes .ReadLocalNode , * args ) -> ir .SQLGlotIR :
149
- pa_table = node .local_data_source .data
150
- pa_table = pa_table .select ([item .source_id for item in node .scan_list .items ])
151
- pa_table = pa_table .rename_columns ([item .id .sql for item in node .scan_list .items ])
144
+ return ir .SQLGlotIR .from_pyarrow (pa_table , node .schema , uid_gen = self .uid_gen )
152
145
153
- offsets = node .offsets_col .sql if node .offsets_col else None
154
- if offsets :
155
- pa_table = pa_table .append_column (
156
- offsets , pa .array (range (pa_table .num_rows ), type = pa .int64 ())
146
+ @_compile_node .register
147
+ def compile_selection (
148
+ self , node : nodes .SelectionNode , child : ir .SQLGlotIR
149
+ ) -> ir .SQLGlotIR :
150
+ selected_cols : tuple [tuple [str , sge .Expression ], ...] = tuple (
151
+ (id .sql , scalar_compiler .compile_scalar_expression (expr ))
152
+ for expr , id in node .input_output_pairs
157
153
)
158
-
159
- return ir .SQLGlotIR .from_pyarrow (pa_table , node .schema )
154
+ return child .select (selected_cols )
160
155
161
156
162
- @_compile_node .register
163
- def compile_selection (node : nodes .SelectionNode , child : ir .SQLGlotIR ) -> ir .SQLGlotIR :
164
- select_cols : typing .Dict [str , sge .Expression ] = {
165
- id .name : scalar_compiler .compile_scalar_expression (expr )
166
- for expr , id in node .input_output_pairs
167
- }
168
- return child .select (select_cols )
157
+ def _replace_unsupported_ops (node : nodes .BigFrameNode ):
158
+ node = nodes .bottom_up (node , rewrite .rewrite_slice )
159
+ node = nodes .bottom_up (node , rewrite .rewrite_timedelta_expressions )
160
+ node = nodes .bottom_up (node , rewrite .rewrite_range_rolling )
161
+ return node
0 commit comments