Skip to content

Commit 3eadf75

Browse files
authored
refactor: add compile_readlocal for SQLGlotCompiler (#1663)
This change adds compile_readlocal to the SQLGlotCompiler class and also introduces pytest-snapshot as a dev dependency for unit tests.
1 parent f442e7a commit 3eadf75

File tree

12 files changed

+176
-75
lines changed

12 files changed

+176
-75
lines changed

bigframes/core/compile/compiled.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,7 @@ def _reproject_to_table(self) -> UnorderedIR:
288288
def from_polars(
289289
cls, pa_table: pa.Table, schema: Sequence[bigquery.SchemaField]
290290
) -> UnorderedIR:
291-
# TODO: add offsets
292-
"""
293-
Builds an in-memory only (SQL only) expr from a pandas dataframe.
294-
295-
Assumed that the dataframe has unique string column names and bigframes-supported
296-
dtypes.
297-
"""
291+
"""Builds an in-memory only (SQL only) expr from a pyarrow table."""
298292
import bigframes_vendored.ibis.backends.bigquery.datatypes as third_party_ibis_bqtypes
299293

300294
# derive the ibis schema from the original pandas schema

bigframes/core/compile/compiler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import bigframes_vendored.ibis.expr.api as ibis_api
2323
import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes
2424
import bigframes_vendored.ibis.expr.types as ibis_types
25-
import google.cloud.bigquery
2625
import pyarrow as pa
2726

2827
from bigframes import dtypes, operations
@@ -169,7 +168,7 @@ def compile_readlocal(node: nodes.ReadLocalNode, *args):
169168
pa_table = node.local_data_source.data
170169
bq_schema = node.schema.to_bigquery()
171170

172-
pa_table = pa_table.select(list(item.source_id for item in node.scan_list.items))
171+
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
173172
pa_table = pa_table.rename_columns(
174173
{item.source_id: item.id.sql for item in node.scan_list.items}
175174
)
@@ -178,7 +177,6 @@ def compile_readlocal(node: nodes.ReadLocalNode, *args):
178177
pa_table = pa_table.append_column(
179178
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
180179
)
181-
bq_schema = (*bq_schema, google.cloud.bigquery.SchemaField(offsets, "INT64"))
182180
return compiled.UnorderedIR.from_polars(pa_table, bq_schema)
183181

184182

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,22 @@
1515

1616
import dataclasses
1717
import functools
18+
import itertools
1819
import typing
1920

20-
import google.cloud.bigquery as bigquery
21-
import sqlglot.expressions as sge
21+
from google.cloud import bigquery
22+
import pyarrow as pa
2223

23-
from bigframes.core import expression, nodes, rewrite
24+
from bigframes.core import expression, identifiers, nodes, rewrite
2425
from bigframes.core.compile import configs
25-
from bigframes.core.compile.sqlglot import sql_gen
26+
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2627
import bigframes.core.ordering as bf_ordering
2728

2829

2930
@dataclasses.dataclass(frozen=True)
3031
class SQLGlotCompiler:
3132
"""Compiles BigFrame nodes into SQL using SQLGlot."""
3233

33-
sql_gen = sql_gen.SQLGen()
34-
3534
def compile(
3635
self,
3736
node: nodes.BigFrameNode,
@@ -81,6 +80,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
8180
result_node = typing.cast(
8281
nodes.ResultNode, rewrite.column_pruning(result_node)
8382
)
83+
result_node = _remap_variables(result_node)
8484
sql = self._compile_result_node(result_node)
8585
return configs.CompileResult(
8686
sql, result_node.schema.to_bigquery(), result_node.order_by
@@ -89,6 +89,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
8989
ordering: typing.Optional[bf_ordering.RowOrdering] = result_node.order_by
9090
result_node = dataclasses.replace(result_node, order_by=None)
9191
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
92+
93+
result_node = _remap_variables(result_node)
9294
sql = self._compile_result_node(result_node)
9395
# Return the ordering iff no extra columns are needed to define the row order
9496
if ordering is not None:
@@ -103,9 +105,9 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
103105
)
104106

105107
def _compile_result_node(self, root: nodes.ResultNode) -> str:
106-
sqlglot_expr = compile_node(root.child)
108+
sqlglot_ir = compile_node(root.child)
107109
# TODO: add order_by, limit, and selections to sqlglot_expr
108-
return self.sql_gen.sql(sqlglot_expr)
110+
return sqlglot_ir.sql
109111

110112

111113
def _replace_unsupported_ops(node: nodes.BigFrameNode):
@@ -115,27 +117,52 @@ def _replace_unsupported_ops(node: nodes.BigFrameNode):
115117
return node
116118

117119

120+
def _remap_variables(node: nodes.ResultNode) -> nodes.ResultNode:
121+
"""Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
122+
123+
def anonymous_column_ids() -> typing.Generator[identifiers.ColumnId, None, None]:
124+
for i in itertools.count():
125+
yield identifiers.ColumnId(name=f"bfcol_{i}")
126+
127+
result_node, _ = rewrite.remap_variables(node, anonymous_column_ids())
128+
return typing.cast(nodes.ResultNode, result_node)
129+
130+
118131
@functools.lru_cache(maxsize=5000)
119-
def compile_node(node: nodes.BigFrameNode) -> sge.Expression:
120-
"""Compile node into CompileArrayValue. Caches result."""
132+
def compile_node(node: nodes.BigFrameNode) -> ir.SQLGlotIR:
133+
"""Compiles node into CompileArrayValue. Caches result."""
121134
return node.reduce_up(lambda node, children: _compile_node(node, *children))
122135

123136

124137
@functools.singledispatch
125138
def _compile_node(
126-
node: nodes.BigFrameNode, *compiled_children: sge.Expression
127-
) -> sge.Expression:
139+
node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
140+
) -> ir.SQLGlotIR:
128141
"""Defines transformation but isn't cached, always use compile_node instead"""
129142
raise ValueError(f"Can't compile unrecognized node: {node}")
130143

131144

132145
@_compile_node.register
133-
def compile_readlocal(node: nodes.ReadLocalNode, *args) -> sge.Expression:
134-
# TODO: add support for reading from local files
135-
return sge.select()
146+
def compile_readlocal(node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
147+
offsets = node.offsets_col.sql if node.offsets_col else None
148+
schema_names = node.schema.names
149+
schema_dtypes = node.schema.dtypes
150+
151+
pa_table = node.local_data_source.data
152+
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
153+
pa_table = pa_table.rename_columns(
154+
{item.source_id: item.id.sql for item in node.scan_list.items}
155+
)
156+
157+
if offsets:
158+
pa_table = pa_table.append_column(
159+
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
160+
)
161+
162+
return ir.SQLGlotIR.from_pandas(pa_table.to_pandas(), schema_names, schema_dtypes)
136163

137164

138165
@_compile_node.register
139-
def compile_selection(node: nodes.SelectionNode, child: sge.Expression):
166+
def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR):
140167
# TODO: add support for selection
141168
return child

bigframes/core/compile/sqlglot/sql_gen.py

Lines changed: 0 additions & 38 deletions
This file was deleted.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
import typing
19+
20+
import pandas as pd
21+
import sqlglot as sg
22+
import sqlglot.dialects.bigquery
23+
import sqlglot.expressions as sge
24+
25+
from bigframes import dtypes
26+
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
27+
28+
29+
@dataclasses.dataclass(frozen=True)
30+
class SQLGlotIR:
31+
"""Helper class to build SQLGlot Query and generate SQL string."""
32+
33+
expr: sge.Expression = sge.Expression()
34+
"""The SQLGlot expression representing the query."""
35+
36+
dialect = sqlglot.dialects.bigquery.BigQuery
37+
"""The SQL dialect used for generation."""
38+
39+
quoted: bool = True
40+
"""Whether to quote identifiers in the generated SQL."""
41+
42+
pretty: bool = True
43+
"""Whether to pretty-print the generated SQL."""
44+
45+
@property
46+
def sql(self) -> str:
47+
"""Generate SQL string from the given expression."""
48+
return self.expr.sql(dialect=self.dialect, pretty=self.pretty)
49+
50+
@classmethod
51+
def from_pandas(
52+
cls,
53+
pd_df: pd.DataFrame,
54+
schema_names: typing.Sequence[str],
55+
schema_dtypes: typing.Sequence[dtypes.Dtype],
56+
) -> SQLGlotIR:
57+
"""Builds SQLGlot expression from a pandas DataFrame."""
58+
dtype_expr = sge.DataType(
59+
this=sge.DataType.Type.STRUCT,
60+
expressions=[
61+
sge.ColumnDef(
62+
this=sge.to_identifier(name, quoted=True),
63+
kind=sgt.SQLGlotType.from_bigframes_dtype(dtype),
64+
)
65+
for name, dtype in zip(schema_names, schema_dtypes)
66+
],
67+
nested=True,
68+
)
69+
data_expr = [
70+
sge.Tuple(
71+
expressions=tuple(
72+
_literal(
73+
value=value,
74+
dtype=sgt.SQLGlotType.from_bigframes_dtype(dtype),
75+
)
76+
for value, dtype in zip(row, schema_dtypes)
77+
)
78+
)
79+
for _, row in pd_df.iterrows()
80+
]
81+
expr = sge.Unnest(
82+
expressions=[
83+
sge.DataType(
84+
this=sge.DataType.Type.ARRAY,
85+
expressions=[dtype_expr],
86+
nested=True,
87+
values=data_expr,
88+
),
89+
],
90+
)
91+
return cls(expr=sg.select(sge.Star()).from_(expr))
92+
93+
94+
def _literal(value: typing.Any, dtype: str) -> sge.Expression:
95+
if value is None:
96+
return _cast(sge.Null(), dtype)
97+
98+
# TODO: handle other types like visit_DefaultLiteral
99+
return sge.convert(value)
100+
101+
102+
def _cast(arg, to) -> sge.Cast:
103+
return sge.Cast(this=arg, to=to)

bigframes/core/compile/sqlglot/sqlglot_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def from_bigframes_dtype(
3232
bigframes_dtype: typing.Union[
3333
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
3434
],
35-
):
35+
) -> str:
3636
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
3737
return "INT64"
3838
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:

bigframes/core/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,11 +1636,11 @@ def remap_vars(
16361636
def remap_refs(
16371637
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
16381638
) -> ResultNode:
1639-
output_names = tuple(
1639+
output_cols = tuple(
16401640
(ref.remap_column_refs(mappings), name) for ref, name in self.output_cols
16411641
)
16421642
order_by = self.order_by.remap_column_refs(mappings) if self.order_by else None
1643-
return dataclasses.replace(self, output_names=output_names, order_by=order_by) # type: ignore
1643+
return dataclasses.replace(self, output_cols=output_cols, order_by=order_by) # type: ignore
16441644

16451645
@property
16461646
def consumed_ids(self) -> COLUMN_SET:

noxfile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
]
7777
UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = []
7878
UNIT_TEST_DEPENDENCIES: List[str] = []
79-
UNIT_TEST_EXTRAS: List[str] = []
79+
UNIT_TEST_EXTRAS: List[str] = ["tests"]
8080
UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {
8181
"3.12": ["polars", "scikit-learn"],
8282
}
@@ -203,7 +203,7 @@ def install_unittest_dependencies(session, install_test_extra, *constraints):
203203

204204
if install_test_extra and UNIT_TEST_EXTRAS_BY_PYTHON:
205205
extras = UNIT_TEST_EXTRAS_BY_PYTHON.get(session.python, [])
206-
elif install_test_extra and UNIT_TEST_EXTRAS:
206+
if install_test_extra and UNIT_TEST_EXTRAS:
207207
extras = UNIT_TEST_EXTRAS
208208
else:
209209
extras = []

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
]
7373
extras = {
7474
# Optional test dependencies packages. If they're missed, may skip some tests.
75-
"tests": [],
75+
"tests": ["freezegun", "pytest-snapshot"],
7676
# used for local engine, which is only needed for unit tests at present.
7777
"polars": ["polars >= 1.7.0"],
7878
"scikit-learn": ["scikit-learn>=1.2.2"],
@@ -82,7 +82,6 @@
8282
"pre-commit",
8383
"nox",
8484
"google-cloud-testutils",
85-
"freezegun",
8685
],
8786
}
8887
extras["all"] = list(sorted(frozenset(itertools.chain.from_iterable(extras.values()))))

tests/unit/core/compile/sqlglot/conftest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,15 @@ def compiler_session():
2626
@pytest.fixture(scope="module")
2727
def all_types_df() -> pd.DataFrame:
2828
# TODO: all types pandas dataframes
29-
return pd.DataFrame({})
29+
# TODO: add tests for empty dataframes
30+
df = pd.DataFrame(
31+
{
32+
"int1": pd.Series([1, 2, 3], dtype="Int64"),
33+
"int2": pd.Series([-10, 20, 30], dtype="Int64"),
34+
"bools": pd.Series([True, None, False], dtype="boolean"),
35+
"strings": pd.Series(["b", "aa", "ccc"], dtype="string[pyarrow]"),
36+
},
37+
)
38+
# Add a more complex (non-default) index for test coverage.
39+
df.index = df.index.astype("Int64")
40+
return df
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
*
3+
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` INT64, `bfcol_1` INT64, `bfcol_2` BOOLEAN, `bfcol_3` STRING, `bfcol_4` INT64>>[(1, -10, TRUE, 'b', 0), (2, 20, CAST(NULL AS BOOLEAN), 'aa', 1), (3, 30, FALSE, 'ccc', 2)])

0 commit comments

Comments
 (0)