Skip to content

Commit 0adf548

Browse files
refactor: Simplify compile code paths (#1420)
1 parent 521e987 commit 0adf548

File tree

8 files changed

+86
-121
lines changed

8 files changed

+86
-121
lines changed

bigframes/core/array_value.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,6 @@ def as_cached(
198198
)
199199
return ArrayValue(node)
200200

201-
def _try_evaluate_local(self):
202-
"""Use only for unit testing paths - not fully featured. Will throw exception if fails."""
203-
import bigframes.core.compile
204-
205-
return bigframes.core.compile.test_only_try_evaluate(self.node)
206-
207201
def get_column_type(self, key: str) -> bigframes.dtypes.Dtype:
208202
return self.schema.get_type(key)
209203

bigframes/core/blocks.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,10 @@ def index(self) -> BlockIndexProperties:
213213
@functools.cached_property
214214
def shape(self) -> typing.Tuple[int, int]:
215215
"""Returns dimensions as (length, width) tuple."""
216-
217-
row_count_expr = self.expr.row_count()
218-
219-
# Support in-memory engines for hermetic unit tests.
220-
if self.expr.session is None:
216+
# Support zero-query for hermetic unit tests.
217+
if self.expr.session is None and self.expr.node.row_count:
221218
try:
222-
row_count = row_count_expr._try_evaluate_local().squeeze()
223-
return (row_count, len(self.value_columns))
219+
return self.expr.node.row_count
224220
except Exception:
225221
pass
226222

bigframes/core/compile/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,9 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from bigframes.core.compile.api import (
17-
SQLCompiler,
18-
test_only_ibis_inferred_schema,
19-
test_only_try_evaluate,
20-
)
16+
from bigframes.core.compile.api import SQLCompiler, test_only_ibis_inferred_schema
2117

2218
__all__ = [
2319
"SQLCompiler",
24-
"test_only_try_evaluate",
2520
"test_only_ibis_inferred_schema",
2621
]

bigframes/core/compile/api.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from typing import Mapping, Sequence, Tuple, TYPE_CHECKING
16+
from typing import Optional, Sequence, Tuple, TYPE_CHECKING
1717

1818
import google.cloud.bigquery as bigquery
1919

20-
import bigframes.core.compile.compiler as compiler
20+
from bigframes.core import rewrite
21+
from bigframes.core.compile import compiler
2122

2223
if TYPE_CHECKING:
2324
import bigframes.core.nodes
@@ -31,31 +32,16 @@ class SQLCompiler:
3132
def __init__(self, strict: bool = True):
3233
self._compiler = compiler.Compiler(strict=strict)
3334

34-
def compile_peek(self, node: bigframes.core.nodes.BigFrameNode, n_rows: int) -> str:
35-
"""Compile node into sql that selects N arbitrary rows, may not execute deterministically."""
36-
return self._compiler.compile_peek_sql(node, n_rows)
37-
38-
def compile_unordered(
39-
self,
40-
node: bigframes.core.nodes.BigFrameNode,
41-
*,
42-
col_id_overrides: Mapping[str, str] = {},
43-
) -> str:
44-
"""Compile node into sql where rows are unsorted, and no ordering information is preserved."""
45-
# TODO: Enable limit pullup, but only if not being used to write to clustered table.
46-
output_ids = [col_id_overrides.get(id, id) for id in node.schema.names]
47-
return self._compiler.compile_sql(node, ordered=False, output_ids=output_ids)
48-
49-
def compile_ordered(
35+
def compile(
5036
self,
5137
node: bigframes.core.nodes.BigFrameNode,
5238
*,
53-
col_id_overrides: Mapping[str, str] = {},
39+
ordered: bool = True,
40+
limit: Optional[int] = None,
5441
) -> str:
5542
"""Compile node into sql where rows are sorted with ORDER BY."""
5643
# If we are ordering the query anyways, compiling the slice as a limit is probably a good idea.
57-
output_ids = [col_id_overrides.get(id, id) for id in node.schema.names]
58-
return self._compiler.compile_sql(node, ordered=True, output_ids=output_ids)
44+
return self._compiler.compile_sql(node, ordered=ordered, limit=limit)
5945

6046
def compile_raw(
6147
self,
@@ -67,21 +53,15 @@ def compile_raw(
6753
return self._compiler.compile_raw(node)
6854

6955

70-
def test_only_try_evaluate(node: bigframes.core.nodes.BigFrameNode):
71-
"""Use only for unit testing paths - not fully featured. Will throw exception if fails."""
72-
node = _STRICT_COMPILER._preprocess(node)
73-
ibis = _STRICT_COMPILER.compile_node(node)._to_ibis_expr()
74-
return ibis.pandas.connect({}).execute(ibis)
75-
76-
7756
def test_only_ibis_inferred_schema(node: bigframes.core.nodes.BigFrameNode):
7857
"""Use only for testing paths to ensure ibis inferred schema does not diverge from bigframes inferred schema."""
7958
import bigframes.core.schema
8059

81-
node = _STRICT_COMPILER._preprocess(node)
82-
compiled = _STRICT_COMPILER.compile_node(node)
60+
node = _STRICT_COMPILER._replace_unsupported_ops(node)
61+
node, _ = rewrite.pull_up_order(node, order_root=False)
62+
ir = _STRICT_COMPILER.compile_node(node)
8363
items = tuple(
84-
bigframes.core.schema.SchemaItem(name, compiled.get_column_type(ibis_id))
85-
for name, ibis_id in zip(node.schema.names, compiled.column_ids)
64+
bigframes.core.schema.SchemaItem(name, ir.get_column_type(ibis_id))
65+
for name, ibis_id in zip(node.schema.names, ir.column_ids)
8666
)
8767
return bigframes.core.schema.ArraySchema(items)

bigframes/core/compile/compiled.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,11 @@ def to_sql(
7272
) -> str:
7373
ibis_table = self._to_ibis_expr()
7474
# This set of output transforms maybe should be its own output node??
75-
if order_by or limit:
75+
if (
76+
order_by
77+
or limit
78+
or (selections and (tuple(selections) != tuple(self.column_ids)))
79+
):
7680
sql = ibis_bigquery.Backend().compile(ibis_table)
7781
sql = (
7882
bigframes.core.compile.googlesql.Select()

bigframes/core/compile/compiler.py

Lines changed: 33 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
import bigframes.core.compile.ibis_types
3434
import bigframes.core.compile.scalar_op_compiler as compile_scalar
3535
import bigframes.core.compile.schema_translator
36-
import bigframes.core.expression as ex
37-
import bigframes.core.identifiers as ids
3836
import bigframes.core.nodes as nodes
3937
import bigframes.core.ordering as bf_ordering
4038
import bigframes.core.rewrite as rewrites
@@ -52,65 +50,54 @@ class Compiler:
5250
scalar_op_compiler = compile_scalar.ScalarOpCompiler()
5351

5452
def compile_sql(
55-
self, node: nodes.BigFrameNode, ordered: bool, output_ids: typing.Sequence[str]
53+
self,
54+
node: nodes.BigFrameNode,
55+
ordered: bool,
56+
limit: typing.Optional[int] = None,
5657
) -> str:
57-
# TODO: get rid of output_ids arg
58-
assert len(output_ids) == len(list(node.fields))
59-
node = set_output_names(node, output_ids)
60-
node = nodes.top_down(node, rewrites.rewrite_timedelta_expressions)
58+
# later steps might add ids, so snapshot before those steps.
59+
output_ids = node.schema.names
6160
if ordered:
62-
node, limit = rewrites.pullup_limit_from_slice(node)
63-
node = nodes.bottom_up(node, rewrites.rewrite_slice)
64-
# TODO: Extract out CTEs
65-
node, ordering = rewrites.pull_up_order(
66-
node, order_root=True, ordered_joins=self.strict
67-
)
68-
node = rewrites.column_pruning(node)
69-
ir = self.compile_node(node)
70-
return ir.to_sql(
71-
order_by=ordering.all_ordering_columns,
72-
limit=limit,
73-
selections=output_ids,
74-
)
75-
else:
76-
node = nodes.bottom_up(node, rewrites.rewrite_slice)
77-
node, _ = rewrites.pull_up_order(
78-
node, order_root=False, ordered_joins=self.strict
79-
)
80-
node = rewrites.column_pruning(node)
81-
ir = self.compile_node(node)
82-
return ir.to_sql(selections=output_ids)
61+
# Need to do this before replacing unsupported ops, as that will rewrite slice ops
62+
node, pulled_up_limit = rewrites.pullup_limit_from_slice(node)
63+
if (pulled_up_limit is not None) and (
64+
(limit is None) or limit > pulled_up_limit
65+
):
66+
limit = pulled_up_limit
8367

84-
def compile_peek_sql(self, node: nodes.BigFrameNode, n_rows: int) -> str:
85-
ids = [id.sql for id in node.ids]
86-
node = nodes.bottom_up(node, rewrites.rewrite_slice)
87-
node = nodes.top_down(node, rewrites.rewrite_timedelta_expressions)
88-
node, _ = rewrites.pull_up_order(
89-
node, order_root=False, ordered_joins=self.strict
68+
node = self._replace_unsupported_ops(node)
69+
# prune before pulling up order to avoid unnecessary row_number() ops
70+
node = rewrites.column_pruning(node)
71+
node, ordering = rewrites.pull_up_order(
72+
node, order_root=ordered, ordered_joins=self.strict
9073
)
74+
# final pruning to clean up any leftover unused values
9175
node = rewrites.column_pruning(node)
92-
return self.compile_node(node).to_sql(limit=n_rows, selections=ids)
76+
return self.compile_node(node).to_sql(
77+
order_by=ordering.all_ordering_columns if ordered else (),
78+
limit=limit,
79+
selections=output_ids,
80+
)
9381

9482
def compile_raw(
9583
self,
96-
node: bigframes.core.nodes.BigFrameNode,
84+
node: nodes.BigFrameNode,
9785
) -> typing.Tuple[
9886
str, typing.Sequence[google.cloud.bigquery.SchemaField], bf_ordering.RowOrdering
9987
]:
100-
node = nodes.bottom_up(node, rewrites.rewrite_slice)
101-
node = nodes.top_down(node, rewrites.rewrite_timedelta_expressions)
102-
node, ordering = rewrites.pull_up_order(node, ordered_joins=self.strict)
88+
node = self._replace_unsupported_ops(node)
89+
node = rewrites.column_pruning(node)
90+
node, ordering = rewrites.pull_up_order(
91+
node, order_root=True, ordered_joins=self.strict
92+
)
10393
node = rewrites.column_pruning(node)
104-
ir = self.compile_node(node)
105-
sql = ir.to_sql()
94+
sql = self.compile_node(node).to_sql()
10695
return sql, node.schema.to_bigquery(), ordering
10796

108-
def _preprocess(self, node: nodes.BigFrameNode):
97+
def _replace_unsupported_ops(self, node: nodes.BigFrameNode):
98+
# TODO: Run all replacement rules as single bottom-up pass
10999
node = nodes.bottom_up(node, rewrites.rewrite_slice)
110-
node = nodes.top_down(node, rewrites.rewrite_timedelta_expressions)
111-
node, _ = rewrites.pull_up_order(
112-
node, order_root=False, ordered_joins=self.strict
113-
)
100+
node = nodes.bottom_up(node, rewrites.rewrite_timedelta_expressions)
114101
return node
115102

116103
# TODO: Remove cache when schema no longer requires compilation to derive schema (and therefore only compiles for execution)
@@ -305,16 +292,3 @@ def compile_explode(self, node: nodes.ExplodeNode):
305292
@_compile_node.register
306293
def compile_random_sample(self, node: nodes.RandomSampleNode):
307294
return self.compile_node(node.child)._uniform_sampling(node.fraction)
308-
309-
310-
def set_output_names(
311-
node: bigframes.core.nodes.BigFrameNode, output_ids: typing.Sequence[str]
312-
):
313-
# TODO: Create specialized output operators that will handle final names
314-
return nodes.SelectionNode(
315-
node,
316-
tuple(
317-
bigframes.core.nodes.AliasedRef(ex.DerefOp(old_id), ids.ColumnId(out_id))
318-
for old_id, out_id in zip(node.ids, output_ids)
319-
),
320-
)

bigframes/core/nodes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,6 +1555,14 @@ def remap_refs(
15551555
return dataclasses.replace(self, column_ids=new_ids) # type: ignore
15561556

15571557

1558+
# Introduced during planning/compilation
1559+
@dataclasses.dataclass(frozen=True, eq=False)
1560+
class ResultNode(UnaryNode):
1561+
output_names: tuple[str, ...]
1562+
order_by: Tuple[OrderingExpression, ...] = ()
1563+
limit: Optional[int] = None
1564+
1565+
15581566
# Tree operators
15591567
def top_down(
15601568
root: BigFrameNode,

bigframes/session/executor.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import pyarrow
4141

4242
import bigframes.core
43+
from bigframes.core import expression
4344
import bigframes.core.compile
4445
import bigframes.core.guid
4546
import bigframes.core.identifiers
@@ -231,11 +232,9 @@ def to_sql(
231232
if enable_cache
232233
else array_value.node
233234
)
234-
if ordered:
235-
return self.compiler.compile_ordered(
236-
node, col_id_overrides=col_id_overrides
237-
)
238-
return self.compiler.compile_unordered(node, col_id_overrides=col_id_overrides)
235+
if col_id_overrides:
236+
node = override_ids(node, col_id_overrides)
237+
return self.compiler.compile(node, ordered=ordered)
239238

240239
def execute(
241240
self,
@@ -377,7 +376,7 @@ def peek(
377376
msg = "Peeking this value cannot be done efficiently."
378377
warnings.warn(msg)
379378

380-
sql = self.compiler.compile_peek(plan, n_rows)
379+
sql = self.compiler.compile(plan, ordered=False, limit=n_rows)
381380

382381
# TODO(swast): plumb through the api_name of the user-facing api that
383382
# caused this query.
@@ -416,7 +415,7 @@ def head(
416415
assert tree_properties.can_fast_head(plan)
417416

418417
head_plan = generate_head_plan(plan, n_rows)
419-
sql = self.compiler.compile_ordered(head_plan)
418+
sql = self.compiler.compile(head_plan)
420419

421420
# TODO(swast): plumb through the api_name of the user-facing api that
422421
# caused this query.
@@ -439,7 +438,7 @@ def get_row_count(self, array_value: bigframes.core.ArrayValue) -> int:
439438
row_count_plan = self.replace_cached_subtrees(
440439
generate_row_count_plan(array_value.node)
441440
)
442-
sql = self.compiler.compile_unordered(row_count_plan)
441+
sql = self.compiler.compile(row_count_plan, ordered=False)
443442
iter, _ = self._run_execute_query(sql)
444443
return next(iter)[0]
445444

@@ -549,8 +548,8 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
549548
"""Executes the query and uses the resulting table to rewrite future executions."""
550549
offset_column = bigframes.core.guid.generate_guid("bigframes_offsets")
551550
w_offsets, offset_column = array_value.promote_offsets()
552-
sql = self.compiler.compile_unordered(
553-
self.replace_cached_subtrees(w_offsets.node)
551+
sql = self.compiler.compile(
552+
self.replace_cached_subtrees(w_offsets.node), ordered=False
554553
)
555554

556555
tmp_table = self._sql_as_cached_temp_table(
@@ -666,3 +665,18 @@ def generate_head_plan(node: nodes.BigFrameNode, n: int):
666665

667666
def generate_row_count_plan(node: nodes.BigFrameNode):
668667
return nodes.RowCountNode(node)
668+
669+
670+
def override_ids(
671+
node: nodes.BigFrameNode, col_id_overrides: Mapping[str, str]
672+
) -> nodes.SelectionNode:
673+
output_ids = [col_id_overrides.get(id, id) for id in node.schema.names]
674+
return nodes.SelectionNode(
675+
node,
676+
tuple(
677+
nodes.AliasedRef(
678+
expression.DerefOp(old_id), bigframes.core.identifiers.ColumnId(out_id)
679+
)
680+
for old_id, out_id in zip(node.ids, output_ids)
681+
),
682+
)

0 commit comments

Comments
 (0)