Skip to content

Commit c958dbe

Browse files
perf: Fold row count ops when known (#1656)
1 parent 3eadf75 commit c958dbe

15 files changed

+160
-149
lines changed

bigframes/core/array_value.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,17 @@ def get_column_type(self, key: str) -> bigframes.dtypes.Dtype:
204204

205205
def row_count(self) -> ArrayValue:
206206
"""Get number of rows in ArrayValue as a single-entry ArrayValue."""
207-
return ArrayValue(nodes.RowCountNode(child=self.node))
207+
return ArrayValue(
208+
nodes.AggregateNode(
209+
child=self.node,
210+
aggregations=(
211+
(
212+
ex.NullaryAggregation(agg_ops.size_op),
213+
ids.ColumnId(bigframes.core.guid.generate_guid()),
214+
),
215+
),
216+
)
217+
)
208218

209219
# Operations
210220
def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:

bigframes/core/blocks.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
Optional,
4040
Sequence,
4141
Tuple,
42-
TYPE_CHECKING,
4342
Union,
4443
)
4544
import warnings
@@ -68,13 +67,8 @@
6867
import bigframes.core.window_spec as windows
6968
import bigframes.dtypes
7069
import bigframes.exceptions as bfe
71-
import bigframes.features
7270
import bigframes.operations as ops
7371
import bigframes.operations.aggregations as agg_ops
74-
import bigframes.session._io.pandas as io_pandas
75-
76-
if TYPE_CHECKING:
77-
import bigframes.session.executor
7872

7973
# Type constraint for wherever column labels are used
8074
Label = typing.Hashable
@@ -221,7 +215,7 @@ def shape(self) -> typing.Tuple[int, int]:
221215
except Exception:
222216
pass
223217

224-
row_count = self.session._executor.get_row_count(self.expr)
218+
row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar()
225219
return (row_count, len(self.value_columns))
226220

227221
@property
@@ -485,7 +479,7 @@ def to_arrow(
485479
*,
486480
ordered: bool = True,
487481
allow_large_results: Optional[bool] = None,
488-
) -> Tuple[pa.Table, bigquery.QueryJob]:
482+
) -> Tuple[pa.Table, Optional[bigquery.QueryJob]]:
489483
"""Run query and download results as a pyarrow Table."""
490484
execute_result = self.session._executor.execute(
491485
self.expr, ordered=ordered, use_explicit_destination=allow_large_results
@@ -580,7 +574,7 @@ def try_peek(
580574
result = self.session._executor.peek(
581575
self.expr, n, use_explicit_destination=allow_large_results
582576
)
583-
df = io_pandas.arrow_to_pandas(result.to_arrow_table(), self.expr.schema)
577+
df = result.to_pandas()
584578
self._copy_index_to_pandas(df)
585579
return df
586580
else:
@@ -604,8 +598,7 @@ def to_pandas_batches(
604598
page_size=page_size,
605599
max_results=max_results,
606600
)
607-
for record_batch in execute_result.arrow_batches():
608-
df = io_pandas.arrow_to_pandas(record_batch, self.expr.schema)
601+
for df in execute_result.to_pandas_batches():
609602
self._copy_index_to_pandas(df)
610603
if squeeze:
611604
yield df.squeeze(axis=1)
@@ -659,7 +652,7 @@ def _materialize_local(
659652

660653
# TODO: Maybe materialize before downsampling
661654
# Some downsampling methods
662-
if fraction < 1:
655+
if fraction < 1 and (execute_result.total_rows is not None):
663656
if not sample_config.enable_downsampling:
664657
raise RuntimeError(
665658
f"The data size ({table_mb:.2f} MB) exceeds the maximum download limit of "
@@ -690,9 +683,7 @@ def _materialize_local(
690683
MaterializationOptions(ordered=materialize_options.ordered)
691684
)
692685
else:
693-
total_rows = execute_result.total_rows
694-
arrow = execute_result.to_arrow_table()
695-
df = io_pandas.arrow_to_pandas(arrow, schema=self.expr.schema)
686+
df = execute_result.to_pandas()
696687
self._copy_index_to_pandas(df)
697688

698689
return df, execute_result.query_job
@@ -1570,12 +1561,11 @@ def retrieve_repr_request_results(
15701561

15711562
# head caches full underlying expression, so row_count will be free after
15721563
head_result = self.session._executor.head(self.expr, max_results)
1573-
count = self.session._executor.get_row_count(self.expr)
1564+
row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar()
15741565

1575-
arrow = head_result.to_arrow_table()
1576-
df = io_pandas.arrow_to_pandas(arrow, schema=self.expr.schema)
1566+
df = head_result.to_pandas()
15771567
self._copy_index_to_pandas(df)
1578-
return df, count, head_result.query_job
1568+
return df, row_count, head_result.query_job
15791569

15801570
def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
15811571
expr, result_id = self._expr.promote_offsets()

bigframes/core/compile/compiler.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,7 @@ def compile_readlocal(node: nodes.ReadLocalNode, *args):
169169
bq_schema = node.schema.to_bigquery()
170170

171171
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
172-
pa_table = pa_table.rename_columns(
173-
{item.source_id: item.id.sql for item in node.scan_list.items}
174-
)
172+
pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
175173

176174
if offsets:
177175
pa_table = pa_table.append_column(
@@ -254,11 +252,6 @@ def compile_concat(node: nodes.ConcatNode, *children: compiled.UnorderedIR):
254252
return concat_impl.concat_unordered(children, output_ids)
255253

256254

257-
@_compile_node.register
258-
def compile_rowcount(node: nodes.RowCountNode, child: compiled.UnorderedIR):
259-
return child.row_count(name=node.col_id.sql)
260-
261-
262255
@_compile_node.register
263256
def compile_aggregate(node: nodes.AggregateNode, child: compiled.UnorderedIR):
264257
aggs = tuple((agg, id.sql) for agg, id in node.aggregations)

bigframes/core/compile/polars/compiler.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,6 @@ def compile_projection(self, node: nodes.ProjectionNode):
252252
]
253253
return self.compile_node(node.child).with_columns(new_cols)
254254

255-
@compile_node.register
256-
def compile_rowcount(self, node: nodes.RowCountNode):
257-
df = cast(pl.LazyFrame, self.compile_node(node.child))
258-
return df.select(pl.len().alias(node.col_id.sql))
259-
260255
@compile_node.register
261256
def compile_offsets(self, node: nodes.PromoteOffsetsNode):
262257
return self.compile_node(node.child).with_columns(

bigframes/core/nodes.py

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,55 +1256,6 @@ def remap_refs(
12561256
return dataclasses.replace(self, assignments=new_fields)
12571257

12581258

1259-
# TODO: Merge RowCount into Aggregate Node?
1260-
# Row count can be compute from table metadata sometimes, so it is a bit special.
1261-
@dataclasses.dataclass(frozen=True, eq=False)
1262-
class RowCountNode(UnaryNode):
1263-
col_id: identifiers.ColumnId = identifiers.ColumnId("count")
1264-
1265-
@property
1266-
def row_preserving(self) -> bool:
1267-
return False
1268-
1269-
@property
1270-
def non_local(self) -> bool:
1271-
return True
1272-
1273-
@property
1274-
def fields(self) -> Sequence[Field]:
1275-
return (Field(self.col_id, bigframes.dtypes.INT_DTYPE, nullable=False),)
1276-
1277-
@property
1278-
def variables_introduced(self) -> int:
1279-
return 1
1280-
1281-
@property
1282-
def defines_namespace(self) -> bool:
1283-
return True
1284-
1285-
@property
1286-
def row_count(self) -> Optional[int]:
1287-
return 1
1288-
1289-
@property
1290-
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
1291-
return (self.col_id,)
1292-
1293-
@property
1294-
def consumed_ids(self) -> COLUMN_SET:
1295-
return frozenset()
1296-
1297-
def remap_vars(
1298-
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1299-
) -> RowCountNode:
1300-
return dataclasses.replace(self, col_id=mappings.get(self.col_id, self.col_id))
1301-
1302-
def remap_refs(
1303-
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1304-
) -> RowCountNode:
1305-
return self
1306-
1307-
13081259
@dataclasses.dataclass(frozen=True, eq=False)
13091260
class AggregateNode(UnaryNode):
13101261
aggregations: typing.Tuple[typing.Tuple[ex.Aggregation, identifiers.ColumnId], ...]
@@ -1642,6 +1593,19 @@ def remap_refs(
16421593
order_by = self.order_by.remap_column_refs(mappings) if self.order_by else None
16431594
return dataclasses.replace(self, output_cols=output_cols, order_by=order_by) # type: ignore
16441595

1596+
@property
1597+
def fields(self) -> Sequence[Field]:
1598+
# Fields property here is for output schema, not to be consumed by a parent node.
1599+
input_fields_by_id = {field.id: field for field in self.child.fields}
1600+
return tuple(
1601+
Field(
1602+
identifiers.ColumnId(output),
1603+
input_fields_by_id[ref.id].dtype,
1604+
input_fields_by_id[ref.id].nullable,
1605+
)
1606+
for ref, output in self.output_cols
1607+
)
1608+
16451609
@property
16461610
def consumed_ids(self) -> COLUMN_SET:
16471611
out_refs = frozenset(ref.id for ref, _ in self.output_cols)

bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from bigframes.core.rewrite.fold_row_count import fold_row_counts
1516
from bigframes.core.rewrite.identifiers import remap_variables
1617
from bigframes.core.rewrite.implicit_align import try_row_join
1718
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
@@ -38,4 +39,5 @@
3839
"try_reduce_to_table_scan",
3940
"bake_order",
4041
"try_reduce_to_local_scan",
42+
"fold_row_counts",
4143
]
bigframes/core/rewrite/fold_row_count.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import pyarrow as pa
17+
18+
from bigframes import dtypes
19+
from bigframes.core import local_data, nodes
20+
from bigframes.operations import aggregations
21+
22+
23+
def fold_row_counts(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
24+
if not isinstance(node, nodes.AggregateNode):
25+
return node
26+
if len(node.by_column_ids) > 0:
27+
return node
28+
if node.child.row_count is None:
29+
return node
30+
for agg, _ in node.aggregations:
31+
if agg.op != aggregations.size_op:
32+
return node
33+
local_data_source = local_data.ManagedArrowTable.from_pyarrow(
34+
pa.table({"count": pa.array([node.child.row_count], type=pa.int64())})
35+
)
36+
scan_list = nodes.ScanList(
37+
tuple(
38+
nodes.ScanItem(out_id, dtypes.INT_DTYPE, "count")
39+
for _, out_id in node.aggregations
40+
)
41+
)
42+
return nodes.ReadLocalNode(
43+
local_data_source=local_data_source, scan_list=scan_list, session=node.session
44+
)

bigframes/core/rewrite/order.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,6 @@ def pull_up_order_inner(
211211
)
212212
new_order = child_order.remap_column_refs(new_select_node.get_id_mapping())
213213
return new_select_node, new_order
214-
elif isinstance(node, bigframes.core.nodes.RowCountNode):
215-
child_result = remove_order(node.child)
216-
return node.replace_child(
217-
child_result
218-
), bigframes.core.ordering.TotalOrdering.from_primary_key([node.col_id])
219214
elif isinstance(node, bigframes.core.nodes.AggregateNode):
220215
if node.has_ordered_ops:
221216
child_result, child_order = pull_up_order_inner(node.child)

bigframes/core/rewrite/pruning.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,17 @@ def prune_columns(node: nodes.BigFrameNode):
5151
if isinstance(node, nodes.SelectionNode):
5252
result = prune_selection_child(node)
5353
elif isinstance(node, nodes.ResultNode):
54-
result = node.replace_child(prune_node(node.child, node.consumed_ids))
54+
result = node.replace_child(
55+
prune_node(
56+
node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1])
57+
)
58+
)
5559
elif isinstance(node, nodes.AggregateNode):
56-
result = node.replace_child(prune_node(node.child, node.consumed_ids))
60+
result = node.replace_child(
61+
prune_node(
62+
node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1])
63+
)
64+
)
5765
elif isinstance(node, nodes.InNode):
5866
result = dataclasses.replace(
5967
node,
@@ -71,7 +79,9 @@ def prune_selection_child(
7179

7280
# Important to check this first
7381
if list(selection.ids) == list(child.ids):
74-
return child
82+
if (ref.ref.id == ref.id for ref in selection.input_output_pairs):
83+
# selection is no-op so just remove it entirely
84+
return child
7585

7686
if isinstance(child, nodes.SelectionNode):
7787
return selection.remap_refs(
@@ -96,6 +106,9 @@ def prune_selection_child(
96106
indices = [
97107
list(child.ids).index(ref.id) for ref, _ in selection.input_output_pairs
98108
]
109+
if len(indices) == 0:
110+
# pushing zero-column selection into concat messes up emitter for now, which doesn't like zero columns
111+
return selection
99112
new_children = []
100113
for concat_node in child.child_nodes:
101114
cc_ids = tuple(concat_node.ids)
@@ -146,7 +159,10 @@ def prune_aggregate(
146159
node: nodes.AggregateNode,
147160
used_cols: AbstractSet[identifiers.ColumnId],
148161
) -> nodes.AggregateNode:
149-
pruned_aggs = tuple(agg for agg in node.aggregations if agg[1] in used_cols)
162+
pruned_aggs = (
163+
tuple(agg for agg in node.aggregations if agg[1] in used_cols)
164+
or node.aggregations[0:1]
165+
)
150166
return dataclasses.replace(node, aggregations=pruned_aggs)
151167

152168

0 commit comments

Comments (0)