11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Any , cast , overload
3+ from typing import TYPE_CHECKING , Any , Literal , cast , overload
44
55import pyarrow as pa
66import pyarrow .compute as pc
2020 native_to_narwhals_dtype ,
2121 nulls_like ,
2222 pad_series ,
23+ zeros ,
2324)
24- from narwhals ._compliant import EagerSeries
25+ from narwhals ._compliant import EagerSeries , EagerSeriesHist
2526from narwhals ._expression_parsing import ExprKind
2627from narwhals ._typing_compat import assert_never
2728from narwhals ._utils import (
3940
4041 import pandas as pd
4142 import polars as pl
42- from typing_extensions import Self , TypeIs
43+ from typing_extensions import Self , TypeAlias , TypeIs
4344
4445 from narwhals ._arrow .dataframe import ArrowDataFrame
4546 from narwhals ._arrow .namespace import ArrowNamespace
5152 Incomplete ,
5253 NullPlacement ,
5354 Order ,
55+ ScalarAny ,
5456 TieBreaker ,
5557 _AsPyType ,
5658 _BasicDataType ,
5759 )
60+ from narwhals ._compliant .series import HistData
5861 from narwhals ._utils import Version , _LimitedContext
5962 from narwhals .dtypes import DType
6063 from narwhals .typing import (
7477 _SliceIndex ,
7578 )
7679
80+ ArrowHistData : TypeAlias = (
81+ "HistData[ChunkedArrayAny, list[ScalarAny] | pa.Int64Array | list[float]]"
82+ )
83+
7784
7885# TODO @dangotbanned: move into `_arrow.utils`
7986# Lots of modules are importing inline
@@ -320,8 +327,6 @@ def mean(self, *, _return_py_scalar: bool = True) -> float:
320327 return maybe_extract_py_scalar (pc .mean (self .native ), _return_py_scalar )
321328
322329 def median (self , * , _return_py_scalar : bool = True ) -> float :
323- from narwhals .exceptions import InvalidOperationError
324-
325330 if not self .dtype .is_numeric ():
326331 msg = "`median` operation not supported for non-numeric input type."
327332 raise InvalidOperationError (msg )
@@ -1021,101 +1026,22 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
10211026 result = pc .if_else (null_mask , lit (None , rank .type ), rank )
10221027 return self ._with_native (result )
10231028
1024- def hist ( # noqa: C901, PLR0912, PLR0915
1025- self ,
1026- bins : list [float | int ] | None ,
1027- * ,
1028- bin_count : int | None ,
1029- include_breakpoint : bool ,
1029+ def hist_from_bins (
1030+ self , bins : list [float ], * , include_breakpoint : bool
10301031 ) -> ArrowDataFrame :
1031- import numpy as np # ignore-banned-import
1032-
1033- from narwhals ._arrow .dataframe import ArrowDataFrame
1034-
1035- def _hist_from_bin_count (bin_count : int ): # type: ignore[no-untyped-def] # noqa: ANN202
1036- d = pc .min_max (self .native )
1037- lower , upper = d ["min" ].as_py (), d ["max" ].as_py ()
1038- if lower == upper :
1039- lower -= 0.5
1040- upper += 0.5
1041- bins = np .linspace (lower , upper , bin_count + 1 )
1042- return _hist_from_bins (bins )
1043-
1044- def _hist_from_bins (bins : Sequence [int | float ]): # type: ignore[no-untyped-def] # noqa: ANN202
1045- bin_indices = np .searchsorted (bins , self .native , side = "left" )
1046- bin_indices = pc .if_else ( # lowest bin is inclusive
1047- pc .equal (self .native , lit (bins [0 ])), 1 , bin_indices
1048- )
1049-
1050- # align unique categories and counts appropriately
1051- obs_cats , obs_counts = np .unique (bin_indices , return_counts = True )
1052- obj_cats = np .arange (1 , len (bins ))
1053- counts = np .zeros_like (obj_cats )
1054- counts [np .isin (obj_cats , obs_cats )] = obs_counts [np .isin (obs_cats , obj_cats )]
1055-
1056- bin_right = bins [1 :]
1057- return counts , bin_right
1058-
1059- counts : Sequence [int | float | pa .Scalar [Any ]] | np .typing .ArrayLike
1060- bin_right : Sequence [int | float | pa .Scalar [Any ]] | np .typing .ArrayLike
1061-
1062- data_count = pc .sum (
1063- pc .invert (pc .or_ (pc .is_nan (self .native ), pc .is_null (self .native ))).cast (
1064- pa .uint8 ()
1065- ),
1066- min_count = 0 ,
1032+ return (
1033+ _ArrowHist .from_series (self , include_breakpoint = include_breakpoint )
1034+ .with_bins (bins )
1035+ .to_frame ()
10671036 )
1068- if bins is not None :
1069- if len (bins ) < 2 :
1070- counts , bin_right = [], []
1071-
1072- elif data_count == pa .scalar (0 , type = pa .uint64 ()): # type:ignore[comparison-overlap]
1073- counts = np .zeros (len (bins ) - 1 )
1074- bin_right = bins [1 :]
1075-
1076- elif len (bins ) == 2 :
1077- counts = [
1078- pc .sum (
1079- pc .and_ (
1080- pc .greater_equal (self .native , lit (float (bins [0 ]))),
1081- pc .less_equal (self .native , lit (float (bins [1 ]))),
1082- ).cast (pa .uint8 ())
1083- )
1084- ]
1085- bin_right = [bins [- 1 ]]
1086- else :
1087- counts , bin_right = _hist_from_bins (bins )
1088-
1089- elif bin_count is not None :
1090- if bin_count == 0 :
1091- counts , bin_right = [], []
1092- elif data_count == pa .scalar (0 , type = pa .uint64 ()): # type:ignore[comparison-overlap]
1093- counts , bin_right = (
1094- np .zeros (bin_count ),
1095- np .linspace (0 , 1 , bin_count + 1 )[1 :],
1096- )
1097- elif bin_count == 1 :
1098- d = pc .min_max (self .native )
1099- lower , upper = d ["min" ], d ["max" ]
1100- if lower == upper :
1101- counts , bin_right = [data_count ], [pc .add (upper , pa .scalar (0.5 ))]
1102- else :
1103- counts , bin_right = [data_count ], [upper ]
1104- else :
1105- counts , bin_right = _hist_from_bin_count (bin_count )
11061037
1107- else : # pragma: no cover
1108- # caller guarantees that either bins or bin_count is specified
1109- msg = "must provide one of `bin_count` or `bins`"
1110- raise InvalidOperationError (msg )
1111-
1112- data : dict [str , Any ] = {}
1113- if include_breakpoint :
1114- data ["breakpoint" ] = bin_right
1115- data ["count" ] = counts
1116-
1117- return ArrowDataFrame (
1118- pa .Table .from_pydict (data ), version = self ._version , validate_column_names = True
1038+ def hist_from_bin_count (
1039+ self , bin_count : int , * , include_breakpoint : bool
1040+ ) -> ArrowDataFrame :
1041+ return (
1042+ _ArrowHist .from_series (self , include_breakpoint = include_breakpoint )
1043+ .with_bin_count (bin_count )
1044+ .to_frame ()
11191045 )
11201046
11211047 def __iter__ (self ) -> Iterator [Any ]:
@@ -1135,8 +1061,6 @@ def __contains__(self, other: Any) -> bool:
11351061 pc .is_in (other_ , self .native ), return_py_scalar = True
11361062 )
11371063 except (ArrowInvalid , ArrowNotImplementedError , ArrowTypeError ) as exc :
1138- from narwhals .exceptions import InvalidOperationError
1139-
11401064 msg = f"Unable to compare other of type { type (other )} with series of type { self .dtype } ."
11411065 raise InvalidOperationError (msg ) from exc
11421066
@@ -1170,3 +1094,90 @@ def struct(self) -> ArrowSeriesStructNamespace:
11701094 return ArrowSeriesStructNamespace (self )
11711095
11721096 ewm_mean = not_implemented ()
1097+
1098+
class _ArrowHist(
    EagerSeriesHist["ChunkedArrayAny", "list[ScalarAny] | pa.Int64Array | list[float]"]
):
    """Histogram helper for `ArrowSeries`, built on `EagerSeriesHist`.

    Supplies the pyarrow/numpy pieces used to compute histogram data:
    emptiness detection, zero-filled results, bin-edge generation and the
    actual counting. The orchestration (`from_series`, `with_bins`,
    `with_bin_count`, `_calculate_breakpoint`, `_data`, `_breakpoint`) is
    inherited and not defined here.
    """

    _series: ArrowSeries

    def to_frame(self) -> ArrowDataFrame:
        """Wrap the collected histogram data (`self._data`) in a dataframe."""
        # NOTE: Constructor typing is too strict for `TypedDict`
        build_table: Incomplete = pa.Table.from_pydict
        series = self._series
        from_native = series.__narwhals_namespace__()._dataframe.from_native
        return from_native(build_table(self._data), context=series)

    # NOTE: *Could* be handled at narwhals-level
    def is_empty_series(self) -> bool:
        """Return True when no element is a usable (non-null, non-NaN) value.

        A series with zero rows also satisfies this, since it has no
        `False` entries in its null mask.
        """
        # NOTE: `ChunkedArray.combine_chunks` returns the concrete array type
        # Stubs say `Array[pa.BooleanScalar]`, which is missing properties
        # https://github.com/zen-xu/pyarrow-stubs/blob/6bedee748bc74feb8513b24bf43d64b24c7fddc8/pyarrow-stubs/__lib_pxi/array.pyi#L2395-L2399
        null_mask = self.native.is_null(nan_is_null=True).combine_chunks()
        usable_count = cast("pa.BooleanArray", null_mask).false_count
        return usable_count == 0

    # NOTE: *Could* be handled at narwhals-level, **iff** we add `nw.repeat`, `nw.linear_space`
    # See https://github.com/narwhals-dev/narwhals/pull/2839#discussion_r2215630696
    def series_empty(self, arg: int | list[float], /) -> ArrowHistData:
        """Build histogram data for a series without usable values.

        Arguments:
            arg: Either a bin count or a list of bin edges.

        Returns:
            A mapping with a `"count"` entry from `_zeros(arg)` and, when
            breakpoints were requested, a `"breakpoint"` entry from
            `_calculate_breakpoint(arg)`.
        """
        count = self._zeros(arg)
        if not self._breakpoint:
            return {"count": count}
        return {"breakpoint": self._calculate_breakpoint(arg), "count": count}

    def _zeros(self, arg: int | list[float], /) -> pa.Int64Array:
        """Return `zeros(n)` with one slot per bin.

        `n` is `arg` itself when it is a bin count, otherwise the number of
        intervals between the given edges (`len(arg) - 1`).
        """
        n_bins = arg if isinstance(arg, int) else len(arg) - 1
        return zeros(n_bins)

    def _linear_space(
        self,
        start: float,
        end: float,
        num_samples: int,
        *,
        closed: Literal["both", "none"] = "both",
    ) -> _1DArray:
        """Return `num_samples` evenly spaced values via `numpy.linspace`.

        `end` is included only when `closed == "both"`.
        """
        # NOTE(review): `closed="none"` maps to `endpoint=False`, which in numpy
        # still yields `start` as the first sample - confirm callers expect that.
        import numpy as np  # ignore-banned-import

        include_end = closed == "both"
        return np.linspace(start=start, stop=end, num=num_samples, endpoint=include_end)

    def _calculate_bins(self, bin_count: int) -> _1DArray:
        """Derive `bin_count + 1` evenly spaced bin edges from the data range."""
        extremes = pc.min_max(self.native)
        low = extremes["min"].as_py()
        high = extremes["max"].as_py()
        if low == high:
            # Zero-width range: widen by half a unit on each side.
            low, high = low - 0.5, high + 0.5
        return self._linear_space(low, high, bin_count + 1)

    def _calculate_hist(self, bins: list[float] | _1DArray) -> ArrowHistData:
        """Count how many values fall into each interval defined by `bins`.

        Arguments:
            bins: Bin edges; consecutive edges delimit one bin.

        Returns:
            A mapping with a `"count"` entry and, when breakpoints were
            requested, a `"breakpoint"` entry holding each bin's right edge.
        """
        values = self.native
        # NOTE: `mypy` refuses to resolve `ndarray.__getitem__`
        # Previously annotated as `list[float]`, but
        # - wasn't accurate to how we implemented it
        # - `pa.scalar` overloads fail to match on `float | np.float64` (but runtime is fine)
        bins = cast("list[float]", bins)

        # Single bin: count the values inside the closed interval [bins[0], bins[1]]
        if len(bins) == 2:
            at_least_low = pc.greater_equal(values, lit(bins[0]))
            at_most_high = pc.less_equal(values, lit(bins[1]))
            total = pc.sum(pc.and_(at_least_low, at_most_high).cast(pa.uint8()))
            if not self._breakpoint:
                return {"count": [total]}
            return {"breakpoint": [bins[-1]], "count": [total]}

        # Multiple bins
        import numpy as np  # ignore-banned-import

        # `side="left"` places a value equal to an edge in the bin ending at that edge
        positions = np.searchsorted(bins, values, side="left")
        # The lowest bin is inclusive: values equal to the first edge go to bin 1
        positions = pc.if_else(pc.equal(values, lit(bins[0])), 1, positions)

        # Line up the observed bin indices with the full range `1..len(bins) - 1`;
        # observed indices outside that range are left out of the counts.
        seen, seen_counts = np.unique(positions, return_counts=True)
        expected = np.arange(1, len(bins))
        counts = np.zeros_like(expected)
        expected_was_seen = np.isin(expected, seen)
        seen_is_expected = np.isin(seen, expected)
        counts[expected_was_seen] = seen_counts[seen_is_expected]

        if not self._breakpoint:
            return {"count": counts}
        return {"breakpoint": bins[1:], "count": counts}
0 commit comments