Skip to content

Commit 83aeb4f

Browse files
perf: Fold row count ops when known
1 parent 6199023 commit 83aeb4f

File tree

12 files changed

+103
-125
lines changed

12 files changed

+103
-125
lines changed

bigframes/core/array_value.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,17 @@ def get_column_type(self, key: str) -> bigframes.dtypes.Dtype:
204204

205205
def row_count(self) -> ArrayValue:
206206
"""Get number of rows in ArrayValue as a single-entry ArrayValue."""
207-
return ArrayValue(nodes.RowCountNode(child=self.node))
207+
return ArrayValue(
208+
nodes.AggregateNode(
209+
child=self.node,
210+
aggregations=(
211+
(
212+
ex.NullaryAggregation(agg_ops.size_op),
213+
ids.ColumnId(bigframes.core.guid.generate_guid()),
214+
),
215+
),
216+
)
217+
)
208218

209219
# Operations
210220
def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:

bigframes/core/blocks.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,14 @@ def shape(self) -> typing.Tuple[int, int]:
221221
except Exception:
222222
pass
223223

224-
row_count = self.session._executor.get_row_count(self.expr)
224+
row_count = next(
225+
iter(
226+
self.session._executor.execute(self.expr.row_count())
227+
.to_arrow_table()
228+
.to_pydict()
229+
.values()
230+
)
231+
)[0]
225232
return (row_count, len(self.value_columns))
226233

227234
@property
@@ -485,7 +492,7 @@ def to_arrow(
485492
*,
486493
ordered: bool = True,
487494
allow_large_results: Optional[bool] = None,
488-
) -> Tuple[pa.Table, bigquery.QueryJob]:
495+
) -> Tuple[pa.Table, Optional[bigquery.QueryJob]]:
489496
"""Run query and download results as a pyarrow Table."""
490497
execute_result = self.session._executor.execute(
491498
self.expr, ordered=ordered, use_explicit_destination=allow_large_results
@@ -659,7 +666,7 @@ def _materialize_local(
659666

660667
# TODO: Maybe materialize before downsampling
661668
# Some downsampling methods
662-
if fraction < 1:
669+
if fraction < 1 and (execute_result.total_rows is not None):
663670
if not sample_config.enable_downsampling:
664671
raise RuntimeError(
665672
f"The data size ({table_mb:.2f} MB) exceeds the maximum download limit of "
@@ -690,7 +697,6 @@ def _materialize_local(
690697
MaterializationOptions(ordered=materialize_options.ordered)
691698
)
692699
else:
693-
total_rows = execute_result.total_rows
694700
arrow = execute_result.to_arrow_table()
695701
df = io_pandas.arrow_to_pandas(arrow, schema=self.expr.schema)
696702
self._copy_index_to_pandas(df)
@@ -1570,12 +1576,19 @@ def retrieve_repr_request_results(
15701576

15711577
# head caches full underlying expression, so row_count will be free after
15721578
head_result = self.session._executor.head(self.expr, max_results)
1573-
count = self.session._executor.get_row_count(self.expr)
1579+
row_count = next(
1580+
iter(
1581+
self.session._executor.execute(self.expr.row_count())
1582+
.to_arrow_table()
1583+
.to_pydict()
1584+
.values()
1585+
)
1586+
)[0]
15741587

15751588
arrow = head_result.to_arrow_table()
15761589
df = io_pandas.arrow_to_pandas(arrow, schema=self.expr.schema)
15771590
self._copy_index_to_pandas(df)
1578-
return df, count, head_result.query_job
1591+
return df, row_count, head_result.query_job
15791592

15801593
def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
15811594
expr, result_id = self._expr.promote_offsets()

bigframes/core/compile/compiler.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,6 @@ def compile_concat(node: nodes.ConcatNode, *children: compiled.UnorderedIR):
270270
return concat_impl.concat_unordered(children, output_ids)
271271

272272

273-
@_compile_node.register
274-
def compile_rowcount(node: nodes.RowCountNode, child: compiled.UnorderedIR):
275-
return child.row_count(name=node.col_id.sql)
276-
277-
278273
@_compile_node.register
279274
def compile_aggregate(node: nodes.AggregateNode, child: compiled.UnorderedIR):
280275
aggs = tuple((agg, id.sql) for agg, id in node.aggregations)

bigframes/core/compile/polars/compiler.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,6 @@ def compile_projection(self, node: nodes.ProjectionNode):
252252
]
253253
return self.compile_node(node.child).with_columns(new_cols)
254254

255-
@compile_node.register
256-
def compile_rowcount(self, node: nodes.RowCountNode):
257-
df = cast(pl.LazyFrame, self.compile_node(node.child))
258-
return df.select(pl.len().alias(node.col_id.sql))
259-
260255
@compile_node.register
261256
def compile_offsets(self, node: nodes.PromoteOffsetsNode):
262257
return self.compile_node(node.child).with_columns(

bigframes/core/nodes.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,55 +1256,6 @@ def remap_refs(
12561256
return dataclasses.replace(self, assignments=new_fields)
12571257

12581258

1259-
# TODO: Merge RowCount into Aggregate Node?
1260-
# Row count can be compute from table metadata sometimes, so it is a bit special.
1261-
@dataclasses.dataclass(frozen=True, eq=False)
1262-
class RowCountNode(UnaryNode):
1263-
col_id: identifiers.ColumnId = identifiers.ColumnId("count")
1264-
1265-
@property
1266-
def row_preserving(self) -> bool:
1267-
return False
1268-
1269-
@property
1270-
def non_local(self) -> bool:
1271-
return True
1272-
1273-
@property
1274-
def fields(self) -> Sequence[Field]:
1275-
return (Field(self.col_id, bigframes.dtypes.INT_DTYPE, nullable=False),)
1276-
1277-
@property
1278-
def variables_introduced(self) -> int:
1279-
return 1
1280-
1281-
@property
1282-
def defines_namespace(self) -> bool:
1283-
return True
1284-
1285-
@property
1286-
def row_count(self) -> Optional[int]:
1287-
return 1
1288-
1289-
@property
1290-
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
1291-
return (self.col_id,)
1292-
1293-
@property
1294-
def consumed_ids(self) -> COLUMN_SET:
1295-
return frozenset()
1296-
1297-
def remap_vars(
1298-
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1299-
) -> RowCountNode:
1300-
return dataclasses.replace(self, col_id=mappings.get(self.col_id, self.col_id))
1301-
1302-
def remap_refs(
1303-
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1304-
) -> RowCountNode:
1305-
return self
1306-
1307-
13081259
@dataclasses.dataclass(frozen=True, eq=False)
13091260
class AggregateNode(UnaryNode):
13101261
aggregations: typing.Tuple[typing.Tuple[ex.Aggregation, identifiers.ColumnId], ...]

bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from bigframes.core.rewrite.fold_row_count import fold_row_counts
1516
from bigframes.core.rewrite.identifiers import remap_variables
1617
from bigframes.core.rewrite.implicit_align import try_row_join
1718
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
@@ -38,4 +39,5 @@
3839
"try_reduce_to_table_scan",
3940
"bake_order",
4041
"try_reduce_to_local_scan",
42+
"fold_row_counts",
4143
]
bigframes/core/rewrite/fold_row_count.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import pyarrow as pa
17+
18+
from bigframes import dtypes
19+
from bigframes.core import local_data, nodes
20+
from bigframes.operations import aggregations
21+
22+
23+
def fold_row_counts(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
24+
if not isinstance(node, nodes.AggregateNode):
25+
return node
26+
if len(node.by_column_ids) > 0:
27+
return node
28+
if node.child.row_count is None:
29+
return node
30+
for agg, _ in node.aggregations:
31+
if agg.op != aggregations.size_op:
32+
return node
33+
local_data_source = local_data.ManagedArrowTable.from_pyarrow(
34+
pa.table({"count": pa.array([node.child.row_count], type=pa.int64())})
35+
)
36+
scan_list = nodes.ScanList(
37+
tuple(
38+
nodes.ScanItem(out_id, dtypes.INT_DTYPE, "count")
39+
for _, out_id in node.aggregations
40+
)
41+
)
42+
return nodes.ReadLocalNode(
43+
local_data_source=local_data_source, scan_list=scan_list, session=node.session
44+
)

bigframes/core/rewrite/order.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,6 @@ def pull_up_order_inner(
211211
)
212212
new_order = child_order.remap_column_refs(new_select_node.get_id_mapping())
213213
return new_select_node, new_order
214-
elif isinstance(node, bigframes.core.nodes.RowCountNode):
215-
child_result = remove_order(node.child)
216-
return node.replace_child(
217-
child_result
218-
), bigframes.core.ordering.TotalOrdering.from_primary_key([node.col_id])
219214
elif isinstance(node, bigframes.core.nodes.AggregateNode):
220215
if node.has_ordered_ops:
221216
child_result, child_order = pull_up_order_inner(node.child)

0 commit comments

Comments (0)