Draft
72 commits
4d33b68
feat(expr-ir): Getting started on `GroupBy`
dangotbanned Sep 15, 2025
3718690
feat(DRAFT): mock up `resolve_group_by`
dangotbanned Sep 15, 2025
f70c021
fix: re-sync `GroupByKeys`
dangotbanned Sep 16, 2025
3828ea4
feat: Make `rewrite_projections(keys)` optional
dangotbanned Sep 16, 2025
feb7661
feat: Add `FrozenSchema.merge`
dangotbanned Sep 16, 2025
561909c
chore: more informative placeholder error
dangotbanned Sep 16, 2025
7a811b6
feat(DRAFT): Start spec-ing `CompliantGroupBy`
dangotbanned Sep 18, 2025
6d3c0a9
feat(DRAFT): Implement some of `ArrowGroupBy`
dangotbanned Sep 18, 2025
e71d092
feat(DRAFT): Fill out more of `GroupBy.agg`
dangotbanned Sep 18, 2025
9179ec3
Merge branch 'oh-nodes' into expr-ir/group-by
dangotbanned Sep 18, 2025
8aaf9a9
fix: avoid `typing_extensions` import
dangotbanned Sep 18, 2025
d9b918f
refactor: Move `ArrowGroupBy`
dangotbanned Sep 18, 2025
767261c
feat(DRAFT): Simple cases working?
dangotbanned Sep 19, 2025
648d5d9
feat(expr-ir): Add missing `Expr.len`
dangotbanned Sep 19, 2025
2682b10
feat(expr-ir): Support `nw.len()`
dangotbanned Sep 19, 2025
e1c3145
feat(expr-ir): support auto-implode
dangotbanned Sep 19, 2025
5d36607
feat(DRAFT): Support `nw.col("a").unique()` in `group_by`
dangotbanned Sep 19, 2025
1aa2464
test: Port over `tests/frame/group_by_test`
dangotbanned Sep 19, 2025
45a816f
cov
dangotbanned Sep 19, 2025
8f2ad50
chore: Update todo
dangotbanned Sep 19, 2025
4650456
chore: Add todo for `drop_null_keys=True`
dangotbanned Sep 19, 2025
4a52cec
feat(DRAFT): start custom `pa.TableGroupBy` impl
dangotbanned Sep 20, 2025
ce18f51
fix: Avoid shadowed output aggregation names
dangotbanned Sep 20, 2025
16148b2
feat(expr-ir): Rewrite, fix ordered aggregations
dangotbanned Sep 20, 2025
ce86f8f
test: Port over `first`, `last` group_by tests
dangotbanned Sep 20, 2025
581e511
test: Add failing `drop_null_keys`, `__iter__` tests
dangotbanned Sep 20, 2025
4b77500
feat(expr-ir): Support `group_by(drop_null_keys=True)`
dangotbanned Sep 20, 2025
0a94af6
fix: avoid `typing_extensions` import (again)
dangotbanned Sep 20, 2025
9cb5a86
Merge branch 'oh-nodes' into expr-ir/group-by
dangotbanned Sep 21, 2025
874e736
feat(expr-ir): Reject `drop_null_keys` with `Expr`
dangotbanned Sep 21, 2025
42ccd14
refactor: Return `NamedIR` from `prepare_projection`
dangotbanned Sep 21, 2025
451498c
feat(expr-ir): *Almost* all `Expr` key tests passing!
dangotbanned Sep 21, 2025
780af66
fix(DRAFT): Roughly port over `ParseExprKeysGroupBy`
dangotbanned Sep 22, 2025
e0debe5
refactor: Slightly simplify temp name
dangotbanned Sep 22, 2025
8b622d4
feat: Add temp column name utils
dangotbanned Sep 22, 2025
cad1a47
refactor: Replace temp naming stuff
dangotbanned Sep 22, 2025
57c3e6d
test: Add `Expr.unique` group_by tests
dangotbanned Sep 22, 2025
87cd4a8
fix: Use `operator.or_` instead of `pyarrow.compute.or_`
dangotbanned Sep 23, 2025
d49bcce
test: Steal some of the `polars` test suite 😉
dangotbanned Sep 23, 2025
9d72311
test: `df.group_by(**named_by)`
dangotbanned Sep 23, 2025
479eee6
revert: Don't introduce unused type var
dangotbanned Sep 23, 2025
4391a6f
chore: Remove completed todo
dangotbanned Sep 23, 2025
668db86
docs: Trim `Schema.merge`
dangotbanned Sep 23, 2025
ad4babd
chore: Clean up unused in `arrow.group_by`
dangotbanned Sep 23, 2025
a940e05
test: Add `test_group_by_exclude_keys`
dangotbanned Sep 23, 2025
3a0617d
refactor: Tweak `prepare_excluded`
dangotbanned Sep 23, 2025
28601df
refactor: Refining schema projections
dangotbanned Sep 23, 2025
6122245
refactor: Clean up `resolve_group_by` a bit
dangotbanned Sep 23, 2025
8d1220d
perf: Skip synthesized `FrozenSchema.__init__`
dangotbanned Sep 25, 2025
6dcfa4f
feat: Accept more in `freeze_schema`, `IntoFrozenSchema`, `prepare_pr…
dangotbanned Sep 25, 2025
d7cf2d6
refactor(DRAFT): Add `Grouper`/`Resolver` concepts
dangotbanned Sep 25, 2025
3d0670b
refactor: Move loads of stuff up from `arrow`
dangotbanned Sep 25, 2025
a234cc7
😠😠😠
dangotbanned Sep 25, 2025
9d17006
refactor: Define `group_by_agg`
dangotbanned Sep 26, 2025
8eb5db0
feat(DRAFT): Almost direct port of `ArrowGroupBy.__iter__`
dangotbanned Sep 26, 2025
890732e
test: Add (failing) `test_group_by_expr_iter`
dangotbanned Sep 26, 2025
3776e3a
fix: Select the right columns in `__iter__`
dangotbanned Sep 26, 2025
5b1fa00
chore: Tidy up some comments/notes/docs
dangotbanned Sep 26, 2025
0fac622
excessive comments
dangotbanned Sep 26, 2025
c549736
perf: Use `pc.Expression` instead of eager predicate
dangotbanned Sep 26, 2025
5ef5c53
perf: `remove_column` instead of `drop`
dangotbanned Sep 26, 2025
70523ff
feat: Add `arrow.acero` module
dangotbanned Sep 27, 2025
e148266
refactor: Use a single options class
dangotbanned Sep 27, 2025
1f389df
refactor: renaming/aliasing
dangotbanned Sep 27, 2025
9ba0038
refactor: Split out, rewrite composite key concat
dangotbanned Sep 27, 2025
42872e8
refactor: Use `unique` method
dangotbanned Sep 27, 2025
d098acc
refactor: Remove aliasing that doesn't save lines
dangotbanned Sep 27, 2025
113b6a5
typo
dangotbanned Sep 27, 2025
b35be60
perf: Remove unnecessary `remove_column`
dangotbanned Sep 27, 2025
23565a3
perf: Cached, lazy-loaded options
dangotbanned Sep 27, 2025
7099e04
fix: Simplify, fix, optimize `ArrowDataFrame.row`
dangotbanned Sep 27, 2025
55b1caf
refactor: Clean up more of `__iter__`
dangotbanned Sep 27, 2025
37 changes: 16 additions & 21 deletions narwhals/_plan/_expansion.py
@@ -87,11 +87,10 @@
 Excluded: TypeAlias = "frozenset[str]"
 """Internally use a `set`, then freeze before returning."""

-GroupByKeys: TypeAlias = "Seq[ExprIR]"
-"""Represents group_by keys.
+GroupByKeys: TypeAlias = "Seq[str]"
+"""Represents `group_by` keys.

-- Originates from `polars_plan::plans::conversion::dsl_to_ir::resolve_group_by`
-- Not fully utilized in `narwhals` version yet
+They need to be excluded from expansion.
 """

 OutputNames: TypeAlias = "Seq[str]"
@@ -154,24 +153,23 @@ def with_multiple_columns(self) -> ExpansionFlags:


 def prepare_projection(
-    exprs: Sequence[ExprIR], schema: IntoFrozenSchema
-) -> tuple[Seq[ExprIR], FrozenSchema, OutputNames]:
+    exprs: Sequence[ExprIR], /, keys: GroupByKeys = (), *, schema: IntoFrozenSchema
+) -> tuple[Seq[NamedIR], FrozenSchema]:
     """Expand IRs into named column selections.

-    **Primary entry-point**, will be used by `select`, `with_columns`,
+    **Primary entry-point**, for `select`, `with_columns`,
     and any other context that requires resolving expression names.

     Arguments:
         exprs: IRs that *may* contain things like `Columns`, `SelectorIR`, `Exclude`, etc.
+        keys: Names of `group_by` columns.
         schema: Scope to expand multi-column selectors in.

     Returns:
         `exprs`, rewritten using `Column(name)` only.
     """
     frozen_schema = freeze_schema(schema)
-    rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema)
+    rewritten = rewrite_projections(tuple(exprs), keys=keys, schema=frozen_schema)
     output_names = ensure_valid_exprs(rewritten, frozen_schema)
-    return rewritten, frozen_schema, output_names
+    named_irs = into_named_irs(rewritten, output_names)
+    return named_irs, frozen_schema


 def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]:
@@ -202,7 +200,7 @@ def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames:
 def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR:
     def fn(child: ExprIR, /) -> ExprIR:
         if is_horizontal_reduction(child):
-            rewrites = rewrite_projections(child.input, keys=(), schema=schema)
+            rewrites = rewrite_projections(child.input, schema=schema)
             return common.replace(child, input=rewrites)
         return child

@@ -275,7 +273,7 @@ def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns:
 def rewrite_projections(
     input: Seq[ExprIR],  # `FunctionExpr.input`
     /,
-    keys: GroupByKeys,
+    keys: GroupByKeys = (),
     *,
     schema: FrozenSchema,
 ) -> Seq[ExprIR]:
@@ -323,13 +321,10 @@ def prepare_excluded(
     origin: ExprIR, keys: GroupByKeys, flags: ExpansionFlags, /
 ) -> Excluded:
     """Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555."""
-    exclude: set[str] = set()
-    if flags.has_exclude:
-        exclude.update(*(e.names for e in origin.iter_left() if isinstance(e, Exclude)))
-    for group_by_key in keys:
-        if name := group_by_key.meta.output_name(raise_if_undetermined=False):
-            exclude.add(name)
-    return frozenset(exclude)
+    gb_keys = frozenset(keys)
+    if not flags.has_exclude:
+        return gb_keys
+    return gb_keys.union(*(e.names for e in origin.iter_left() if isinstance(e, Exclude)))


 def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool:
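The simplified `prepare_excluded` boils down to a single `frozenset` union. A standalone sketch of that logic, with plain name sets standing in for the `Exclude` nodes (the function name and inputs here are hypothetical, not narwhals APIs):

```python
def prepare_excluded_sketch(
    keys: list[str], exclude_sets: list[set[str]], *, has_exclude: bool
) -> frozenset[str]:
    # group_by keys are always excluded from expansion
    gb_keys = frozenset(keys)
    if not has_exclude:
        return gb_keys
    # frozenset.union accepts any number of iterables, so the name set
    # gathered from each `Exclude` node can be merged in one call
    return gb_keys.union(*exclude_sets)
```

Since the keys are now plain strings (`Seq[str]`), no per-key output-name resolution is needed, which is what makes the short-circuit on `has_exclude` possible.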
14 changes: 14 additions & 0 deletions narwhals/_plan/_expr_ir.py
@@ -290,3 +290,17 @@ def is_elementwise_top_level(self) -> bool:
         if is_literal(ir):
             return ir.is_scalar
         return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast))
+
+    def is_column(self, *, allow_aliasing: bool = False) -> bool:
+        """Return True if wrapping a single `Column` node.
+
+        Note:
+            Multi-output expressions (including selectors) have been expanded at this stage.
+
+        Arguments:
+            allow_aliasing: If False (default), any aliasing is not considered to be column selection.
+        """
+        from narwhals._plan.expressions import Column
+
+        ir = self.expr
+        return isinstance(ir, Column) and ((self.name == ir.name) or allow_aliasing)
5 changes: 2 additions & 3 deletions narwhals/_plan/_rewrites.py
@@ -4,7 +4,7 @@

 from typing import TYPE_CHECKING

-from narwhals._plan._expansion import into_named_irs, prepare_projection
+from narwhals._plan._expansion import prepare_projection
 from narwhals._plan._guards import (
     is_aggregation,
     is_binary_expr,
@@ -31,8 +31,7 @@ def rewrite_all(
     - Currently we do a full traversal of each tree per-rewrite function
     - There's no caching *after* `prepare_projection` yet
     """
-    out_irs, _, names = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema)
-    named_irs = into_named_irs(out_irs, names)
+    named_irs, _ = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema=schema)
     return tuple(map_ir(ir, *rewrites) for ir in named_irs)


197 changes: 197 additions & 0 deletions narwhals/_plan/arrow/acero.py
@@ -0,0 +1,197 @@
"""Sugar for working with [Acero].

[`pyarrow.acero`] has some building blocks for constructing queries, but is
quite verbose when used directly.

This module aligns some APIs to look more like `polars`.

[Acero]: https://arrow.apache.org/docs/cpp/acero/overview.html
[`pyarrow.acero`]: https://arrow.apache.org/docs/python/api/acero.html
"""

from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union

import pyarrow as pa # ignore-banned-import
import pyarrow.acero as pac
import pyarrow.compute as pc # ignore-banned-import
from pyarrow.acero import Declaration as Decl

from narwhals.typing import SingleColSelector

if TYPE_CHECKING:
    from collections.abc import Iterable

    from typing_extensions import TypeAlias

    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        AggregateOptions as _AggregateOptions,
        Aggregation as _Aggregation,
    )
    from narwhals._plan.typing import Seq
    from narwhals.typing import NonNestedLiteral

T = TypeVar("T")
OneOrListOrTuple: TypeAlias = Union[T, list[T], tuple[T, ...]]
"""WARNING: Don't use this unless there is a runtime check for exactly `list | tuple`."""


Incomplete: TypeAlias = Any
Expr: TypeAlias = pc.Expression
IntoExpr: TypeAlias = "Expr | NonNestedLiteral"
Field: TypeAlias = Union[Expr, SingleColSelector]
"""Anything that passes as a single item in [`_compute._ensure_field_ref`].

[`_compute._ensure_field_ref`]: https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L1507-L1531
"""

AggKeys: TypeAlias = "Iterable[Field] | None"

Target: TypeAlias = OneOrListOrTuple[Field]
Aggregation: TypeAlias = "_Aggregation"
AggregateOptions: TypeAlias = "_AggregateOptions"
Opts: TypeAlias = "AggregateOptions | None"
OutputName: TypeAlias = str
AggSpec: TypeAlias = tuple[Target, Aggregation, Opts, OutputName]


# TODO @dangotbanned: Rename
def pc_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr:
    if isinstance(into, pc.Expression):
        return into
    if isinstance(into, str) and not str_as_lit:
        return pc.field(into)
    arg: Incomplete = into
    return pc.scalar(arg)


def _parse_all_horizontal(predicates: Seq[Expr], constraints: dict[str, Any], /) -> Expr:
    if not constraints and len(predicates) == 1:
        return predicates[0]
    it = (
        pc.field(name) == pc_expr(v, str_as_lit=True) for name, v in constraints.items()
    )
    return reduce(operator.and_, chain(predicates, it))


# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`)
def table_source(native: pa.Table, /) -> Decl:
    """A Source node which accepts a table."""
    return Decl("table_source", options=pac.TableSourceNodeOptions(native))


def _aggregate(agg_specs: Iterable[AggSpec], /, keys: AggKeys = None) -> Decl:
    # NOTE: See https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_acero.pyx#L167-L192
    aggs: Incomplete = agg_specs
    keys_: Incomplete = keys
    return Decl("aggregate", pac.AggregateNodeOptions(aggs, keys=keys_))


# TODO @dangotbanned: Plan
# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`)
def aggregate(aggs: Iterable[AggSpec], /) -> Decl:
    """Scalar aggregate.

    Reduce an array or scalar input to a single scalar output (e.g. computing the mean of a column).
    """
    return _aggregate(aggs)


# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`)
def group_by(keys: AggKeys, aggs: Iterable[AggSpec], /) -> Decl:
    """Hash aggregate.

    Like GROUP BY in SQL: first partition data based on one or more key columns,
    then reduce the data in each partition.
    """
    return _aggregate(aggs, keys=keys)


def filter(*predicates: Expr, **constraints: IntoExpr) -> Decl:
    """Selects rows where all expressions evaluate to True.

    Arguments:
        predicates: [`Expression`]s which must all have a return type of boolean.
        constraints: Column filters; use `name = value` to filter columns by the supplied value.

    Notes:
        - Uses logic similar to [`polars`] for an AND-reduction
        - Elements where the filter does not evaluate to True are discarded, **including nulls**

    [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html
    [`polars`]: https://github.com/pola-rs/polars/blob/d0914d416ce4e1dfcb5f946875ffd1181e31c493/py-polars/polars/_utils/parse/expr.py#L199-L242
    """
    expr = _parse_all_horizontal(predicates, constraints)
    return Decl("filter", options=pac.FilterNodeOptions(expr))


# TODO @dangotbanned: Plan
def select(*exprs: IntoExpr, **named_exprs: IntoExpr) -> Decl:
    raise NotImplementedError


# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`)
def project(**named_exprs: Expr) -> Decl:
    """Make a node which executes expressions on input batches, producing batches of the same length with new columns.

    The "project" operation rearranges, deletes, transforms, and
    creates columns. Each output column is computed by evaluating
    an expression against the source record batch. These must be
    scalar expressions (expressions consisting of scalar literals,
    field references and scalar functions, i.e. elementwise functions
    that return one value for each input row independent of the value
    of all other rows).
    """
    # NOTE: Both just need to be sized and iterable
    names: Incomplete = named_exprs.keys()
    exprs: Incomplete = named_exprs.values()
    return Decl("project", options=pac.ProjectNodeOptions(exprs, names))


# TODO @dangotbanned: Find which option class this uses
def order_by(
    sort_keys: tuple[tuple[str, Literal["ascending", "descending"]], ...] = (),
    *,
    null_placement: Literal["at_start", "at_end"] = "at_end",
) -> Decl:
    return Decl(
        "order_by", pac.OrderByNodeOptions(sort_keys, null_placement=null_placement)
    )


# TODO @dangotbanned: Docs
def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table:
    # NOTE: stubs + docs say `list`, but impl allows any iterable
    decls: Incomplete = declarations
    return Decl.from_sequence(decls).to_table(use_threads=use_threads)


# NOTE: Composite functions are suffixed with `_table`
def group_by_table(
    native: pa.Table, keys: AggKeys, aggs: Iterable[AggSpec], *, use_threads: bool
) -> pa.Table:
    """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`].

    - Backport of [apache/arrow#36768].
    - `first` and `last` were [broken in `pyarrow==13`].
    - Also allows us to specify our own aliases for aggregate output columns.
    - Fixes [narwhals-dev/narwhals#1612]

    [`pa.TableGroupBy.aggregate`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626
    [`pa.acero._group_by`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/acero.py#L412-L418
    [apache/arrow#36768]: https://github.com/apache/arrow/pull/36768
    [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709
    [narwhals-dev/narwhals#1612]: https://github.com/narwhals-dev/narwhals/issues/1612
    """
    return collect(table_source(native), group_by(keys, aggs), use_threads=use_threads)


# TODO @dangotbanned: Docs?
def filter_table(native: pa.Table, *predicates: Expr, **constraints: Any) -> pa.Table:
    return collect(table_source(native), filter(*predicates, **constraints))