
import dataclasses
import functools
-import itertools
import typing

from google.cloud import bigquery
import pyarrow as pa
import sqlglot.expressions as sge

-from bigframes.core import expression, identifiers, nodes, rewrite
+from bigframes.core import expression, guid, identifiers, nodes, rewrite
from bigframes.core.compile import configs
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
import bigframes.core.ordering as bf_ordering


-@dataclasses.dataclass(frozen=True)
class SQLGlotCompiler:
    """Compiles BigFrame nodes into SQL using SQLGlot."""

+    uid_gen: guid.SequentialUIDGenerator
+    """Generator for unique identifiers."""
+
+    def __init__(self):
+        self.uid_gen = guid.SequentialUIDGenerator()
+
    def compile(
        self,
        node: nodes.BigFrameNode,
@@ -82,7 +86,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
        result_node = typing.cast(
            nodes.ResultNode, rewrite.column_pruning(result_node)
        )
-        result_node = _remap_variables(result_node)
+        result_node = self._remap_variables(result_node)
        sql = self._compile_result_node(result_node)
        return configs.CompileResult(
            sql, result_node.schema.to_bigquery(), result_node.order_by
@@ -92,7 +96,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
        result_node = dataclasses.replace(result_node, order_by=None)
        result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))

-        result_node = _remap_variables(result_node)
+        result_node = self._remap_variables(result_node)
        sql = self._compile_result_node(result_node)
        # Return the ordering iff no extra columns are needed to define the row order
        if ordering is not None:
@@ -106,63 +110,62 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
            sql, result_node.schema.to_bigquery(), output_order
        )

+    def _remap_variables(self, node: nodes.ResultNode) -> nodes.ResultNode:
+        """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
+
+        result_node, _ = rewrite.remap_variables(
+            node, map(identifiers.ColumnId, self.uid_gen.get_uid_stream("bfcol_"))
+        )
+        return typing.cast(nodes.ResultNode, result_node)
+
    def _compile_result_node(self, root: nodes.ResultNode) -> str:
-        sqlglot_ir = compile_node(root.child)
+        sqlglot_ir = self.compile_node(root.child)
        # TODO: add order_by, limit, and selections to sqlglot_expr
        return sqlglot_ir.sql

+    @functools.lru_cache(maxsize=5000)
+    def compile_node(self, node: nodes.BigFrameNode) -> ir.SQLGlotIR:
+        """Compiles node into CompileArrayValue. Caches result."""
+        return node.reduce_up(
+            lambda node, children: self._compile_node(node, *children)
+        )

-def _replace_unsupported_ops(node: nodes.BigFrameNode):
-    node = nodes.bottom_up(node, rewrite.rewrite_slice)
-    node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
-    node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
-    return node
-
-
-def _remap_variables(node: nodes.ResultNode) -> nodes.ResultNode:
-    """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
-
-    def anonymous_column_ids() -> typing.Generator[identifiers.ColumnId, None, None]:
-        for i in itertools.count():
-            yield identifiers.ColumnId(name=f"bfcol_{i}")
-
-    result_node, _ = rewrite.remap_variables(node, anonymous_column_ids())
-    return typing.cast(nodes.ResultNode, result_node)
-
-
-@functools.lru_cache(maxsize=5000)
-def compile_node(node: nodes.BigFrameNode) -> ir.SQLGlotIR:
-    """Compiles node into CompileArrayValue. Caches result."""
-    return node.reduce_up(lambda node, children: _compile_node(node, *children))
-
-
-@functools.singledispatch
-def _compile_node(
-    node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
-) -> ir.SQLGlotIR:
-    """Defines transformation but isn't cached, always use compile_node instead"""
-    raise ValueError(f"Can't compile unrecognized node: {node}")
+    @functools.singledispatchmethod
+    def _compile_node(
+        self, node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
+    ) -> ir.SQLGlotIR:
+        """Defines transformation but isn't cached, always use compile_node instead"""
+        raise ValueError(f"Can't compile unrecognized node: {node}")
+
+    @_compile_node.register
+    def compile_readlocal(self, node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
+        pa_table = node.local_data_source.data
+        pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
+        pa_table = pa_table.rename_columns(
+            [item.id.sql for item in node.scan_list.items]
+        )

+        offsets = node.offsets_col.sql if node.offsets_col else None
+        if offsets:
+            pa_table = pa_table.append_column(
+                offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
+            )

-@_compile_node.register
-def compile_readlocal(node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
-    pa_table = node.local_data_source.data
-    pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
-    pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
+        return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=self.uid_gen)

-    offsets = node.offsets_col.sql if node.offsets_col else None
-    if offsets:
-        pa_table = pa_table.append_column(
-            offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
+    @_compile_node.register
+    def compile_selection(
+        self, node: nodes.SelectionNode, child: ir.SQLGlotIR
+    ) -> ir.SQLGlotIR:
+        selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
+            (id.sql, scalar_compiler.compile_scalar_expression(expr))
+            for expr, id in node.input_output_pairs
        )
+        return child.select(selected_cols)

-    return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema)

-
-@_compile_node.register
-def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
-    select_cols: typing.Dict[str, sge.Expression] = {
-        id.name: scalar_compiler.compile_scalar_expression(expr)
-        for expr, id in node.input_output_pairs
-    }
-    return child.select(select_cols)
+def _replace_unsupported_ops(node: nodes.BigFrameNode):
+    node = nodes.bottom_up(node, rewrite.rewrite_slice)
+    node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
+    node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
+    return node
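
Note on the dispatch change: the module-level functions registered with functools.singledispatch become functools.singledispatchmethod methods on SQLGlotCompiler. A plain singledispatch-decorated method would dispatch on self, so singledispatchmethod is the usual way to keep per-node-type handlers while giving them access to instance state such as uid_gen. A minimal, self-contained sketch of that pattern (the Node classes and return values below are placeholders, not bigframes types):

import functools


class NodeA:
    pass


class NodeB:
    pass


class Compiler:
    @functools.singledispatchmethod
    def _compile_node(self, node, *children):
        # Fallback for unregistered node types, mirroring the ValueError in the diff.
        raise ValueError(f"Can't compile unrecognized node: {node}")

    # singledispatchmethod dispatches on the first argument after `self`,
    # using the type annotation of the registered method.
    @_compile_node.register
    def _compile_node_a(self, node: NodeA, *children) -> str:
        return "compiled NodeA"

    @_compile_node.register
    def _compile_node_b(self, node: NodeB, *children) -> str:
        return "compiled NodeB"


compiler = Compiler()
assert compiler._compile_node(NodeA()) == "compiled NodeA"
assert compiler._compile_node(NodeB()) == "compiled NodeB"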
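
Note on the identifier change: the removed anonymous_column_ids generator and the new self.uid_gen.get_uid_stream("bfcol_") stream both feed rewrite.remap_variables with deterministic IDs of the form bfcol_0, bfcol_1, and so on. A rough stand-in for such a sequential UID stream, assuming get_uid_stream yields a prefixed counter per prefix (this class is illustrative only, not the bigframes.core.guid implementation):

import itertools
from typing import Dict, Iterator


class SequentialUIDGeneratorSketch:
    """Illustrative stand-in that yields 'bfcol_0', 'bfcol_1', ... for a given prefix."""

    def __init__(self) -> None:
        self._counters: Dict[str, Iterator[int]] = {}

    def get_uid_stream(self, prefix: str) -> Iterator[str]:
        # Each prefix gets its own monotonically increasing counter.
        counter = self._counters.setdefault(prefix, itertools.count())
        while True:
            yield f"{prefix}{next(counter)}"


gen = SequentialUIDGeneratorSketch()
stream = gen.get_uid_stream("bfcol_")
assert [next(stream) for _ in range(3)] == ["bfcol_0", "bfcol_1", "bfcol_2"]

Holding the generator on the compiler instance, rather than building a fresh counter inside each _remap_variables call, presumably keeps bfcol_ names from colliding across repeated compilations by the same compiler.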