Commit 79373f9

refactor: Simplify compliant Series.hist (#2839)
* WIP
* pandas-like looks good
* simplify a bunch more
* fix typing?
* pyarrow, polars and compliant as well
* ignore banned import
* ignore banned import
* refactor: Reuse `_polars.utils.BACKEND_VERSION`
* address use of numpy/cupy
* pragma: no cover all old polars paths
* refactor: Add `zeros`
* perf: Use `is_null(nan_is_null=True)`
* refactor _prepare_bins
* remove unnecessary cast
* polars: use fill_nan and better typing
* refactor hist into _hist_from_bins and _hist_from_bin_count
* no cover old polars, unstable decorator
* perf: `ArrowSeries._is_empty_series` w/ fewer ops (#2839 (comment))
* test: Start trying to clean up (related: #2839 (comment); finding these tests very difficult to follow)
* refactor: Getting creative w/ `polars` (trying to move as much as possible into expressions)
* refactor: even more expressions!
* float | int -> float
* apply suggestions
* rename compliant functionalities
* fix typo
* refactor(suggestion): Try out generalizing everything
  - If this were only used in the `pandas` case, it is just more code
  - But we can reuse large chunks of it and extract others into new public api
* rm unnecessary renaming
* conditional renaming
* one more try
* refactor: add compliant level parent class for `hist` (#2882)
  * refactor: add compliant level parent class
  * fix(typing): remove unused ignores
  * fix(typing): Use `pa.Int64Array`
  * fix(typing): `Generic` -> `Protocol` (#2882 (comment))
  * fix(typing): Resolve most invariance issues (#2882 (comment))
  * chore(typing): Ignore `linspace` for now (I wanna move this back into `pyarrow` anyway, #2839 (comment))
  * docs(typing): Explain why `cast` bins
  * move histdata into type-checking block
  * chore(typing): `CompliantSeries` -> `EagerSeries`
  * chore(typing): `Any` -> `EagerDataFrameAny`
  * docs(typing): note constructor issue (this one gets me every time)
  * rm '_' prefix, stay native in _linear_space

Co-authored-by: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
1 parent 5d60ef5 commit 79373f9
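
For orientation: the refactor replaces the old single `ArrowSeries.hist(bins, *, bin_count, include_breakpoint)` entry point with two narrower methods, `hist_from_bins` and `hist_from_bin_count`, both delegating to the new compliant-level `EagerSeriesHist` helper. A minimal dispatch sketch of how a caller might pick between them (hypothetical code for illustration only, not the narwhals-level implementation), mirroring the old invariant that exactly one of `bins` / `bin_count` is provided:

    # Hypothetical caller-side sketch; `series` is assumed to be an ArrowSeries.
    def dispatch_hist(series, bins=None, *, bin_count=None, include_breakpoint=False):
        if bins is not None:
            return series.hist_from_bins(bins, include_breakpoint=include_breakpoint)
        if bin_count is not None:
            return series.hist_from_bin_count(bin_count, include_breakpoint=include_breakpoint)
        # old invariant, kept here only for illustration
        raise ValueError("must provide one of `bin_count` or `bins`")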

File tree

10 files changed: +499 -323 lines

narwhals/_arrow/series.py

Lines changed: 110 additions & 99 deletions
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, cast, overload
+from typing import TYPE_CHECKING, Any, Literal, cast, overload

 import pyarrow as pa
 import pyarrow.compute as pc
@@ -20,8 +20,9 @@
     native_to_narwhals_dtype,
     nulls_like,
     pad_series,
+    zeros,
 )
-from narwhals._compliant import EagerSeries
+from narwhals._compliant import EagerSeries, EagerSeriesHist
 from narwhals._expression_parsing import ExprKind
 from narwhals._typing_compat import assert_never
 from narwhals._utils import (
@@ -39,7 +40,7 @@

     import pandas as pd
     import polars as pl
-    from typing_extensions import Self, TypeIs
+    from typing_extensions import Self, TypeAlias, TypeIs

     from narwhals._arrow.dataframe import ArrowDataFrame
     from narwhals._arrow.namespace import ArrowNamespace
@@ -51,10 +52,12 @@
         Incomplete,
         NullPlacement,
         Order,
+        ScalarAny,
         TieBreaker,
         _AsPyType,
         _BasicDataType,
     )
+    from narwhals._compliant.series import HistData
     from narwhals._utils import Version, _LimitedContext
     from narwhals.dtypes import DType
     from narwhals.typing import (
@@ -74,6 +77,10 @@
         _SliceIndex,
     )

+    ArrowHistData: TypeAlias = (
+        "HistData[ChunkedArrayAny, list[ScalarAny] | pa.Int64Array | list[float]]"
+    )
+

 # TODO @dangotbanned: move into `_arrow.utils`
 # Lots of modules are importing inline
@@ -320,8 +327,6 @@ def mean(self, *, _return_py_scalar: bool = True) -> float:
         return maybe_extract_py_scalar(pc.mean(self.native), _return_py_scalar)

     def median(self, *, _return_py_scalar: bool = True) -> float:
-        from narwhals.exceptions import InvalidOperationError
-
         if not self.dtype.is_numeric():
             msg = "`median` operation not supported for non-numeric input type."
             raise InvalidOperationError(msg)
@@ -1021,101 +1026,22 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
         result = pc.if_else(null_mask, lit(None, rank.type), rank)
         return self._with_native(result)

-    def hist(  # noqa: C901, PLR0912, PLR0915
-        self,
-        bins: list[float | int] | None,
-        *,
-        bin_count: int | None,
-        include_breakpoint: bool,
+    def hist_from_bins(
+        self, bins: list[float], *, include_breakpoint: bool
     ) -> ArrowDataFrame:
-        import numpy as np  # ignore-banned-import
-
-        from narwhals._arrow.dataframe import ArrowDataFrame
-
-        def _hist_from_bin_count(bin_count: int):  # type: ignore[no-untyped-def] # noqa: ANN202
-            d = pc.min_max(self.native)
-            lower, upper = d["min"].as_py(), d["max"].as_py()
-            if lower == upper:
-                lower -= 0.5
-                upper += 0.5
-            bins = np.linspace(lower, upper, bin_count + 1)
-            return _hist_from_bins(bins)
-
-        def _hist_from_bins(bins: Sequence[int | float]):  # type: ignore[no-untyped-def] # noqa: ANN202
-            bin_indices = np.searchsorted(bins, self.native, side="left")
-            bin_indices = pc.if_else(  # lowest bin is inclusive
-                pc.equal(self.native, lit(bins[0])), 1, bin_indices
-            )
-
-            # align unique categories and counts appropriately
-            obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
-            obj_cats = np.arange(1, len(bins))
-            counts = np.zeros_like(obj_cats)
-            counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)]
-
-            bin_right = bins[1:]
-            return counts, bin_right
-
-        counts: Sequence[int | float | pa.Scalar[Any]] | np.typing.ArrayLike
-        bin_right: Sequence[int | float | pa.Scalar[Any]] | np.typing.ArrayLike
-
-        data_count = pc.sum(
-            pc.invert(pc.or_(pc.is_nan(self.native), pc.is_null(self.native))).cast(
-                pa.uint8()
-            ),
-            min_count=0,
+        return (
+            _ArrowHist.from_series(self, include_breakpoint=include_breakpoint)
+            .with_bins(bins)
+            .to_frame()
         )
-        if bins is not None:
-            if len(bins) < 2:
-                counts, bin_right = [], []
-
-            elif data_count == pa.scalar(0, type=pa.uint64()):  # type:ignore[comparison-overlap]
-                counts = np.zeros(len(bins) - 1)
-                bin_right = bins[1:]
-
-            elif len(bins) == 2:
-                counts = [
-                    pc.sum(
-                        pc.and_(
-                            pc.greater_equal(self.native, lit(float(bins[0]))),
-                            pc.less_equal(self.native, lit(float(bins[1]))),
-                        ).cast(pa.uint8())
-                    )
-                ]
-                bin_right = [bins[-1]]
-            else:
-                counts, bin_right = _hist_from_bins(bins)
-
-        elif bin_count is not None:
-            if bin_count == 0:
-                counts, bin_right = [], []
-            elif data_count == pa.scalar(0, type=pa.uint64()):  # type:ignore[comparison-overlap]
-                counts, bin_right = (
-                    np.zeros(bin_count),
-                    np.linspace(0, 1, bin_count + 1)[1:],
-                )
-            elif bin_count == 1:
-                d = pc.min_max(self.native)
-                lower, upper = d["min"], d["max"]
-                if lower == upper:
-                    counts, bin_right = [data_count], [pc.add(upper, pa.scalar(0.5))]
-                else:
-                    counts, bin_right = [data_count], [upper]
-            else:
-                counts, bin_right = _hist_from_bin_count(bin_count)

-        else:  # pragma: no cover
-            # caller guarantees that either bins or bin_count is specified
-            msg = "must provide one of `bin_count` or `bins`"
-            raise InvalidOperationError(msg)
-
-        data: dict[str, Any] = {}
-        if include_breakpoint:
-            data["breakpoint"] = bin_right
-        data["count"] = counts
-
-        return ArrowDataFrame(
-            pa.Table.from_pydict(data), version=self._version, validate_column_names=True
+    def hist_from_bin_count(
+        self, bin_count: int, *, include_breakpoint: bool
+    ) -> ArrowDataFrame:
+        return (
+            _ArrowHist.from_series(self, include_breakpoint=include_breakpoint)
+            .with_bin_count(bin_count)
+            .to_frame()
         )

     def __iter__(self) -> Iterator[Any]:
@@ -1135,8 +1061,6 @@ def __contains__(self, other: Any) -> bool:
                 pc.is_in(other_, self.native), return_py_scalar=True
             )
         except (ArrowInvalid, ArrowNotImplementedError, ArrowTypeError) as exc:
-            from narwhals.exceptions import InvalidOperationError
-
             msg = f"Unable to compare other of type {type(other)} with series of type {self.dtype}."
             raise InvalidOperationError(msg) from exc

@@ -1170,3 +1094,90 @@ def struct(self) -> ArrowSeriesStructNamespace:
         return ArrowSeriesStructNamespace(self)

     ewm_mean = not_implemented()
+
+
+class _ArrowHist(
+    EagerSeriesHist["ChunkedArrayAny", "list[ScalarAny] | pa.Int64Array | list[float]"]
+):
+    _series: ArrowSeries
+
+    def to_frame(self) -> ArrowDataFrame:
+        # NOTE: Constructor typing is too strict for `TypedDict`
+        table: Incomplete = pa.Table.from_pydict
+        from_native = self._series.__narwhals_namespace__()._dataframe.from_native
+        return from_native(table(self._data), context=self._series)
+
+    # NOTE: *Could* be handled at narwhals-level
+    def is_empty_series(self) -> bool:
+        # NOTE: `ChunkedArray.combine_chunks` returns the concrete array type
+        # Stubs say `Array[pa.BooleanScalar]`, which is missing properties
+        # https://github.com/zen-xu/pyarrow-stubs/blob/6bedee748bc74feb8513b24bf43d64b24c7fddc8/pyarrow-stubs/__lib_pxi/array.pyi#L2395-L2399
+        is_null = self.native.is_null(nan_is_null=True)
+        arr = cast("pa.BooleanArray", is_null.combine_chunks())
+        return arr.false_count == 0
+
+    # NOTE: *Could* be handled at narwhals-level, **iff** we add `nw.repeat`, `nw.linear_space`
+    # See https://github.com/narwhals-dev/narwhals/pull/2839#discussion_r2215630696
+    def series_empty(self, arg: int | list[float], /) -> ArrowHistData:
+        count = self._zeros(arg)
+        if self._breakpoint:
+            return {"breakpoint": self._calculate_breakpoint(arg), "count": count}
+        return {"count": count}
+
+    def _zeros(self, arg: int | list[float], /) -> pa.Int64Array:
+        return zeros(arg) if isinstance(arg, int) else zeros(len(arg) - 1)
+
+    def _linear_space(
+        self,
+        start: float,
+        end: float,
+        num_samples: int,
+        *,
+        closed: Literal["both", "none"] = "both",
+    ) -> _1DArray:
+        from numpy import linspace  # ignore-banned-import
+
+        return linspace(start=start, stop=end, num=num_samples, endpoint=closed == "both")
+
+    def _calculate_bins(self, bin_count: int) -> _1DArray:
+        """Prepare bins for histogram calculation from bin_count."""
+        d = pc.min_max(self.native)
+        lower, upper = d["min"].as_py(), d["max"].as_py()
+        if lower == upper:
+            lower -= 0.5
+            upper += 0.5
+        return self._linear_space(lower, upper, bin_count + 1)
+
+    def _calculate_hist(self, bins: list[float] | _1DArray) -> ArrowHistData:
+        ser = self.native
+        # NOTE: `mypy` refuses to resolve `ndarray.__getitem__`
+        # Previously annotated as `list[float]`, but
+        # - wasn't accurate to how we implemented it
+        # - `pa.scalar` overloads fail to match on `float | np.float64` (but runtime is fine)
+        bins = cast("list[float]", bins)
+        # Handle single bin case
+        if len(bins) == 2:
+            is_between_bins = pc.and_(
+                pc.greater_equal(ser, lit(bins[0])), pc.less_equal(ser, lit(bins[1]))
+            )
+            count = pc.sum(is_between_bins.cast(pa.uint8()))
+            if self._breakpoint:
+                return {"breakpoint": [bins[-1]], "count": [count]}
+            return {"count": [count]}
+
+        # Handle multiple bins
+        import numpy as np  # ignore-banned-import
+
+        bin_indices = np.searchsorted(bins, ser, side="left")
+        # lowest bin is inclusive
+        bin_indices = pc.if_else(pc.equal(ser, lit(bins[0])), 1, bin_indices)

+        # Align unique categories and counts appropriately
+        obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
+        obj_cats = np.arange(1, len(bins))
+        counts = np.zeros_like(obj_cats)
+        counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)]
+
+        if self._breakpoint:
+            return {"breakpoint": bins[1:], "count": counts}
+        return {"count": counts}
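
The multi-bin branch of `_calculate_hist` is easier to follow on plain NumPy data. A small standalone sketch of the same counting idea (pure NumPy here; the real code routes the lowest-bin fix-up through `pc.if_else` on the Arrow ChunkedArray):

    import numpy as np

    values = np.array([1.0, 1.0, 3.5, 4.0])
    bins = [1.0, 2.0, 3.0, 4.0]  # three right-closed bins: (1, 2], (2, 3], (3, 4]

    # side="left" maps each value to the first edge >= value, i.e. bin i is (bins[i-1], bins[i]]
    bin_indices = np.searchsorted(bins, values, side="left")
    # the lowest bin is inclusive on the left, so values equal to bins[0] belong to bin 1
    bin_indices[values == bins[0]] = 1

    # align observed bin labels with the full label range, leaving unobserved bins at zero
    obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
    all_cats = np.arange(1, len(bins))
    counts = np.zeros_like(all_cats)
    counts[np.isin(all_cats, obs_cats)] = obs_counts[np.isin(obs_cats, all_cats)]

    print(counts.tolist())  # [2, 0, 2]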

narwhals/_arrow/utils.py

Lines changed: 4 additions & 0 deletions
@@ -100,6 +100,10 @@ def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
     return pa.nulls(n, series.native.type)


+def zeros(n: int, /) -> pa.Int64Array:
+    return pa.repeat(0, n)
+
+
 @lru_cache(maxsize=16)
 def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:  # noqa: C901, PLR0912
     dtypes = version.dtypes
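
A quick usage sketch of the new helper (illustration only, not part of the commit): `zeros(n)` is simply `pa.repeat(0, n)`, an `int64` array of `n` zeros, which `_ArrowHist` uses as the counts for an empty or all-null series.

    import pyarrow as pa

    arr = pa.repeat(0, 3)             # what `zeros(3)` returns
    print(arr.type, arr.to_pylist())  # int64 [0, 0, 0]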

narwhals/_compliant/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
     EagerSeries,
     EagerSeriesCatNamespace,
     EagerSeriesDateTimeNamespace,
+    EagerSeriesHist,
     EagerSeriesListNamespace,
     EagerSeriesNamespace,
     EagerSeriesStringNamespace,
@@ -82,6 +83,7 @@
     "EagerSeries",
     "EagerSeriesCatNamespace",
     "EagerSeriesDateTimeNamespace",
+    "EagerSeriesHist",
     "EagerSeriesListNamespace",
     "EagerSeriesNamespace",
     "EagerSeriesStringNamespace",
