perf: Fold row count ops when known #1656


Merged · 10 commits · Apr 29, 2025
12 changes: 11 additions & 1 deletion bigframes/core/array_value.py
@@ -204,7 +204,17 @@ def get_column_type(self, key: str) -> bigframes.dtypes.Dtype:
 
     def row_count(self) -> ArrayValue:
         """Get number of rows in ArrayValue as a single-entry ArrayValue."""
-        return ArrayValue(nodes.RowCountNode(child=self.node))
+        return ArrayValue(
+            nodes.AggregateNode(
+                child=self.node,
+                aggregations=(
+                    (
+                        ex.NullaryAggregation(agg_ops.size_op),
+                        ids.ColumnId(bigframes.core.guid.generate_guid()),
+                    ),
+                ),
+            )
+        )
 
     # Operations
     def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
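For intuition: the dedicated RowCountNode was always a special case of aggregation. A size aggregate over zero grouping keys yields exactly one group (the whole table), whose value is the row count. A toy illustration of that equivalence in plain Python, with hypothetical stand-ins rather than bigframes internals:

def size_agg(rows: list) -> int:
    # Nullary aggregation: inspects no column, just counts input rows.
    return len(rows)

table = [{"a": 1}, {"a": 2}, {"a": 3}]
# Zero grouping keys means one group spanning the whole table, so the
# aggregate output is a single row holding the row count.
result = [{"count": size_agg(table)}]
assert result == [{"count": 3}]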
28 changes: 9 additions & 19 deletions bigframes/core/blocks.py
@@ -39,7 +39,6 @@
     Optional,
     Sequence,
     Tuple,
-    TYPE_CHECKING,
     Union,
 )
 import warnings
@@ -68,13 +67,8 @@
 import bigframes.core.window_spec as windows
 import bigframes.dtypes
 import bigframes.exceptions as bfe
-import bigframes.features
 import bigframes.operations as ops
 import bigframes.operations.aggregations as agg_ops
-import bigframes.session._io.pandas as io_pandas
-
-if TYPE_CHECKING:
-    import bigframes.session.executor
 
 # Type constraint for wherever column labels are used
 Label = typing.Hashable
@@ -221,7 +215,7 @@ def shape(self) -> typing.Tuple[int, int]:
         except Exception:
             pass
 
-        row_count = self.session._executor.get_row_count(self.expr)
+        row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar()
         return (row_count, len(self.value_columns))
 
     @property
@@ -485,7 +479,7 @@ def to_arrow(
         *,
         ordered: bool = True,
         allow_large_results: Optional[bool] = None,
-    ) -> Tuple[pa.Table, bigquery.QueryJob]:
+    ) -> Tuple[pa.Table, Optional[bigquery.QueryJob]]:
         """Run query and download results as a pyarrow Table."""
         execute_result = self.session._executor.execute(
             self.expr, ordered=ordered, use_explicit_destination=allow_large_results
@@ -580,7 +574,7 @@ def try_peek(
             result = self.session._executor.peek(
                 self.expr, n, use_explicit_destination=allow_large_results
             )
-            df = io_pandas.arrow_to_pandas(result.to_arrow_table(), self.expr.schema)
+            df = result.to_pandas()
             self._copy_index_to_pandas(df)
             return df
         else:
@@ -604,8 +598,7 @@ def to_pandas_batches(
             page_size=page_size,
             max_results=max_results,
         )
-        for record_batch in execute_result.arrow_batches():
-            df = io_pandas.arrow_to_pandas(record_batch, self.expr.schema)
+        for df in execute_result.to_pandas_batches():
             self._copy_index_to_pandas(df)
             if squeeze:
                 yield df.squeeze(axis=1)
@@ -659,7 +652,7 @@ def _materialize_local(
 
         # TODO: Maybe materialize before downsampling
         # Some downsampling methods
-        if fraction < 1:
+        if fraction < 1 and (execute_result.total_rows is not None):
             if not sample_config.enable_downsampling:
                 raise RuntimeError(
                     f"The data size ({table_mb:.2f} MB) exceeds the maximum download limit of "
@@ -690,9 +683,7 @@
                 MaterializationOptions(ordered=materialize_options.ordered)
             )
         else:
-            total_rows = execute_result.total_rows
-            arrow = execute_result.to_arrow_table()
-            df = io_pandas.arrow_to_pandas(arrow, schema=self.expr.schema)
+            df = execute_result.to_pandas()
             self._copy_index_to_pandas(df)
 
         return df, execute_result.query_job
@@ -1570,12 +1561,11 @@ def retrieve_repr_request_results(
 
         # head caches full underlying expression, so row_count will be free after
         head_result = self.session._executor.head(self.expr, max_results)
-        count = self.session._executor.get_row_count(self.expr)
+        row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar()
 
-        arrow = head_result.to_arrow_table()
-        df = io_pandas.arrow_to_pandas(arrow, schema=self.expr.schema)
+        df = head_result.to_pandas()
         self._copy_index_to_pandas(df)
-        return df, count, head_result.query_job
+        return df, row_count, head_result.query_job
 
     def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
         expr, result_id = self._expr.promote_offsets()
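The removed get_row_count executor entry point is replaced throughout by the same round-trip: execute the plan built by ArrayValue.row_count() and unwrap the single-cell result. A minimal sketch of that call pattern as used twice in the hunks above; block_row_count itself is a hypothetical helper, not part of this PR:

def block_row_count(block) -> int:
    # row_count() builds a one-row, one-column size aggregation; executing it
    # and calling to_py_scalar() unwraps the count as a plain Python int.
    result = block.session._executor.execute(block.expr.row_count())
    return result.to_py_scalar()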
9 changes: 1 addition & 8 deletions bigframes/core/compile/compiler.py
@@ -169,9 +169,7 @@ def compile_readlocal(node: nodes.ReadLocalNode, *args):
     bq_schema = node.schema.to_bigquery()
 
     pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
-    pa_table = pa_table.rename_columns(
-        {item.source_id: item.id.sql for item in node.scan_list.items}
-    )
+    pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
 
     if offsets:
         pa_table = pa_table.append_column(
@@ -254,11 +252,6 @@ def compile_concat(node: nodes.ConcatNode, *children: compiled.UnorderedIR):
     return concat_impl.concat_unordered(children, output_ids)
 
 
-@_compile_node.register
-def compile_rowcount(node: nodes.RowCountNode, child: compiled.UnorderedIR):
-    return child.row_count(name=node.col_id.sql)
-
-
 @_compile_node.register
 def compile_aggregate(node: nodes.AggregateNode, child: compiled.UnorderedIR):
     aggs = tuple((agg, id.sql) for agg, id in node.aggregations)
5 changes: 0 additions & 5 deletions bigframes/core/compile/polars/compiler.py
@@ -252,11 +252,6 @@ def compile_projection(self, node: nodes.ProjectionNode):
         ]
         return self.compile_node(node.child).with_columns(new_cols)
 
-    @compile_node.register
-    def compile_rowcount(self, node: nodes.RowCountNode):
-        df = cast(pl.LazyFrame, self.compile_node(node.child))
-        return df.select(pl.len().alias(node.col_id.sql))
-
     @compile_node.register
     def compile_offsets(self, node: nodes.PromoteOffsetsNode):
         return self.compile_node(node.child).with_columns(
62 changes: 13 additions & 49 deletions bigframes/core/nodes.py
@@ -1256,55 +1256,6 @@ def remap_refs(
         return dataclasses.replace(self, assignments=new_fields)
 
 
-# TODO: Merge RowCount into Aggregate Node?
-# Row count can be compute from table metadata sometimes, so it is a bit special.
-@dataclasses.dataclass(frozen=True, eq=False)
-class RowCountNode(UnaryNode):
-    col_id: identifiers.ColumnId = identifiers.ColumnId("count")
-
-    @property
-    def row_preserving(self) -> bool:
-        return False
-
-    @property
-    def non_local(self) -> bool:
-        return True
-
-    @property
-    def fields(self) -> Sequence[Field]:
-        return (Field(self.col_id, bigframes.dtypes.INT_DTYPE, nullable=False),)
-
-    @property
-    def variables_introduced(self) -> int:
-        return 1
-
-    @property
-    def defines_namespace(self) -> bool:
-        return True
-
-    @property
-    def row_count(self) -> Optional[int]:
-        return 1
-
-    @property
-    def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
-        return (self.col_id,)
-
-    @property
-    def consumed_ids(self) -> COLUMN_SET:
-        return frozenset()
-
-    def remap_vars(
-        self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
-    ) -> RowCountNode:
-        return dataclasses.replace(self, col_id=mappings.get(self.col_id, self.col_id))
-
-    def remap_refs(
-        self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
-    ) -> RowCountNode:
-        return self
-
-
 @dataclasses.dataclass(frozen=True, eq=False)
 class AggregateNode(UnaryNode):
     aggregations: typing.Tuple[typing.Tuple[ex.Aggregation, identifiers.ColumnId], ...]
@@ -1642,6 +1593,19 @@ def remap_refs(
         order_by = self.order_by.remap_column_refs(mappings) if self.order_by else None
         return dataclasses.replace(self, output_cols=output_cols, order_by=order_by)  # type: ignore
 
+    @property
+    def fields(self) -> Sequence[Field]:
+        # Fields property here is for output schema, not to be consumed by a parent node.
+        input_fields_by_id = {field.id: field for field in self.child.fields}
+        return tuple(
+            Field(
+                identifiers.ColumnId(output),
+                input_fields_by_id[ref.id].dtype,
+                input_fields_by_id[ref.id].nullable,
+            )
+            for ref, output in self.output_cols
+        )
+
     @property
     def consumed_ids(self) -> COLUMN_SET:
         out_refs = frozenset(ref.id for ref, _ in self.output_cols)
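The new fields property derives the node's output schema by looking up each referenced input field and reusing its dtype and nullability under the output name. A self-contained toy version of that lookup, using stand-in types rather than the bigframes Field and ColumnId classes:

from dataclasses import dataclass

@dataclass(frozen=True)
class Field:
    id: str
    dtype: str
    nullable: bool

child_fields = (Field("c1", "Int64", False), Field("c2", "string", True))
output_cols = (("c2", "name"), ("c1", "key"))  # (input ref, output name) pairs

# Index input fields by id, then project each output column's dtype/nullability.
by_id = {f.id: f for f in child_fields}
schema = tuple(Field(out, by_id[ref].dtype, by_id[ref].nullable) for ref, out in output_cols)
assert schema[0] == Field("name", "string", True)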
2 changes: 2 additions & 0 deletions bigframes/core/rewrite/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from bigframes.core.rewrite.fold_row_count import fold_row_counts
 from bigframes.core.rewrite.identifiers import remap_variables
 from bigframes.core.rewrite.implicit_align import try_row_join
 from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
@@ -38,4 +39,5 @@
     "try_reduce_to_table_scan",
     "bake_order",
     "try_reduce_to_local_scan",
+    "fold_row_counts",
 ]
44 changes: 44 additions & 0 deletions bigframes/core/rewrite/fold_row_count.py
@@ -0,0 +1,44 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import pyarrow as pa
+
+from bigframes import dtypes
+from bigframes.core import local_data, nodes
+from bigframes.operations import aggregations
+
+
+def fold_row_counts(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
[Review comment on this line]

Contributor: A naming suggestion for this module: I think we can name the file "row_counts". The good thing is that if you want to add more rewrite functions for row counts in the future, they will all look like row_counts.expand(...), row_counts.reverse(...), which keeps the code organized.

Contributor (author): fold_row_count isn't the module though; core.compile.rewrite is. For now this file just defines a single symbol, so it seems best to be as precise as possible. Easy enough to change later if more is added to this file.

+    if not isinstance(node, nodes.AggregateNode):
+        return node
+    if len(node.by_column_ids) > 0:
+        return node
+    if node.child.row_count is None:
+        return node
+    for agg, _ in node.aggregations:
+        if agg.op != aggregations.size_op:
+            return node
+    local_data_source = local_data.ManagedArrowTable.from_pyarrow(
+        pa.table({"count": pa.array([node.child.row_count], type=pa.int64())})
+    )
+    scan_list = nodes.ScanList(
+        tuple(
+            nodes.ScanItem(out_id, dtypes.INT_DTYPE, "count")
+            for _, out_id in node.aggregations
+        )
+    )
+    return nodes.ReadLocalNode(
+        local_data_source=local_data_source, scan_list=scan_list, session=node.session
+    )
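A hypothetical usage sketch of the rewrite (everything except fold_row_counts and the node types is a stand-in): when the aggregate is purely size_op with no grouping keys and the child's row count is already known from metadata, the whole subtree collapses to a one-row local scan.

from bigframes.core import nodes
from bigframes.core.rewrite import fold_row_counts

# Assume count_plan is the AggregateNode built by ArrayValue.row_count()
# over a scan whose row_count property is known (e.g. a local table).
folded = fold_row_counts(count_plan)
if isinstance(folded, nodes.ReadLocalNode):
    # The count was folded to a constant, so no aggregation runs at
    # execution time; the plan just reads a precomputed one-row table.
    pass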
5 changes: 0 additions & 5 deletions bigframes/core/rewrite/order.py
@@ -211,11 +211,6 @@ def pull_up_order_inner(
         )
         new_order = child_order.remap_column_refs(new_select_node.get_id_mapping())
         return new_select_node, new_order
-    elif isinstance(node, bigframes.core.nodes.RowCountNode):
-        child_result = remove_order(node.child)
-        return node.replace_child(
-            child_result
-        ), bigframes.core.ordering.TotalOrdering.from_primary_key([node.col_id])
     elif isinstance(node, bigframes.core.nodes.AggregateNode):
         if node.has_ordered_ops:
             child_result, child_order = pull_up_order_inner(node.child)
24 changes: 20 additions & 4 deletions bigframes/core/rewrite/pruning.py
@@ -51,9 +51,17 @@ def prune_columns(node: nodes.BigFrameNode):
     if isinstance(node, nodes.SelectionNode):
         result = prune_selection_child(node)
     elif isinstance(node, nodes.ResultNode):
-        result = node.replace_child(prune_node(node.child, node.consumed_ids))
+        result = node.replace_child(
+            prune_node(
+                node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1])
+            )
+        )
     elif isinstance(node, nodes.AggregateNode):
-        result = node.replace_child(prune_node(node.child, node.consumed_ids))
+        result = node.replace_child(
+            prune_node(
+                node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1])
+            )
+        )
     elif isinstance(node, nodes.InNode):
         result = dataclasses.replace(
             node,
@@ -71,7 +79,9 @@ def prune_selection_child(
 
     # Important to check this first
     if list(selection.ids) == list(child.ids):
-        return child
+        if all(ref.ref.id == ref.id for ref in selection.input_output_pairs):
+            # selection is no-op so just remove it entirely
+            return child
 
     if isinstance(child, nodes.SelectionNode):
         return selection.remap_refs(
@@ -96,6 +106,9 @@
     indices = [
         list(child.ids).index(ref.id) for ref, _ in selection.input_output_pairs
     ]
+    if len(indices) == 0:
+        # pushing zero-column selection into concat messes up emitter for now, which doesn't like zero columns
+        return selection
     new_children = []
     for concat_node in child.child_nodes:
         cc_ids = tuple(concat_node.ids)
@@ -146,7 +159,10 @@ def prune_aggregate(
     node: nodes.AggregateNode,
     used_cols: AbstractSet[identifiers.ColumnId],
 ) -> nodes.AggregateNode:
-    pruned_aggs = tuple(agg for agg in node.aggregations if agg[1] in used_cols)
+    pruned_aggs = (
+        tuple(agg for agg in node.aggregations if agg[1] in used_cols)
+        or node.aggregations[0:1]
+    )
     return dataclasses.replace(node, aggregations=pruned_aggs)
 
 
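These `or` fallbacks keep pruning from producing zero-column nodes: a nullary size aggregation consumes no input columns, so its child would otherwise be pruned to nothing, and a count output that no parent references by name would otherwise be dropped from the aggregate itself. A small self-contained sketch of the aggregation fallback, using stand-in tuples rather than bigframes types:

aggregations = (("size_op", "col_a"), ("sum_op", "col_b"))  # (op, output id) stand-ins
used_cols: set = set()  # nothing referenced downstream

# Keep the matching aggregations, or fall back to the first one so the
# pruned node still produces at least one column.
pruned = tuple(agg for agg in aggregations if agg[1] in used_cols) or aggregations[0:1]
assert pruned == (("size_op", "col_a"),)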