-
Notifications
You must be signed in to change notification settings - Fork 175
refactor: Simplify compliant Series.hist
#2839
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Series.histSeries.hist
|
Ok I am tagging it as ready for review, but I am cheating on type hints. Will drop some inline comments for the daily struggle with them 😭 |
Series.histSeries.hist
narwhals/_arrow/series.py
Outdated
| elif pc.sum( | ||
| pc.invert(pc.or_(pc.is_nan(self.native), pc.is_null(self.native))).cast( | ||
| pa.uint8() | ||
| pa.uint64() | ||
| ), | ||
| min_count=0, | ||
| ) == pa.scalar(0, type=pa.uint64()): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I tried to put
pc.sum(...)into its own method, but mypy kept complaining that the return type should have beenpa.Scalar[...]where...wasUint8Typewhich does not exist? at least I could not find it in the docs, nor inpanamespace, notpa.dtypes - A different cheat would have been to return after moving it to a python int via
.as_py() - The point above would also make the comparison simpler - I don't know how it was not complaining before, but now comparing uint8 with uint64 was raising another issue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to put
pc.sum(...)into its own method, but mypy kept complaining that the return type should have beenpa.Scalar[...]where ... wasUint8Typewhich does not exist? at least I could not find it in the docs, nor inpanamespace, notpa.dtypes
This was corrected in a more recent version of pyarrow-stubs (zen-xu/pyarrow-stubs#230)
I've been putting off upgrading our pinned version, as:
- they aren't tested with
mypyanymore - the author prefers to ignore
pyrightrules that I mentioned caused us downstream issues
narwhals/_arrow/series.py
Outdated
| bin_right: Sequence[int | float | pa.Scalar[Any]] | np.typing.ArrayLike | ||
|
|
||
| data_count = pc.sum( | ||
| data: dict[str, Any] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really know where to start with this. My best guess was dict[str, Sequence[int | float] | _1DArray | ArrayAny] but still not a happy ending from _hist_from_bins below
narwhals/_polars/series.py
Outdated
| upper += 0.5 | ||
|
|
||
| width = (upper - lower) / bin_count | ||
| bins = pl.int_range(0, bin_count + 1, eager=True) * width + lower # type: ignore[assignment] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previous version was casting to list, which is a bit of a missed opportunity to keep everything native
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you thought about reimplementing the whole of the polars one using expressions?
I think this is the most complicated Series method and I'd guess could benefit the most performance-wise 🫣
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not really! As the main issue with polars is "just" managing all the different versions before and after the call to .hist(...), I didn't invest much time on it 🙈
| """Prepare bins based on backend version compatibility. | ||
| polars <1.15 does not adjust the bins when they have equivalent min/max | ||
| polars <1.5 with bin_count=... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a bit confused by this line comment, but I kept it as it was
- If this were only used in the `pandas` case, it is just more code - But we can reuse large chunks of it and extract others into new public api
|
Hey @FBruzzesi 👋 I've been experimenting with refactor(suggestion): Try out generalizing everything So we should be able to add:
And there's a lot in there that I think could either move up into:
But most importantly, this is only a suggestion that came out of finally understanding |
|
Hey @dangotbanned thanks for suggesting #2839 (comment) From a first look, I really like it. I have a few open questions also considering the Discord exchange we had about
|
|
Thanks for all the great questions 😍!
My eyes are about fall out - but for just this part - I was thinking the shortest path to At least for import polars as pl
import narwhals as nw
data = {"a": [1, 3, 8, 8, 2, 1, 3]}
bins: list[float] = [1, 2, 3]
df = pl.DataFrame(data)
nw_ser = nw.from_native(df).get_column("a")
nw_hist = nw_ser.hist(bins, include_breakpoint=True)
>>> nw_hist
┌──────────────────────┐
| Narwhals DataFrame |
|----------------------|
|shape: (2, 2) |
|┌────────────┬───────┐|
|│ breakpoint ┆ count │|
|│ --- ┆ --- │|
|│ f64 ┆ u32 │|
|╞════════════╪═══════╡|
|│ 2.0 ┆ 3 │|
|│ 3.0 ┆ 2 │|
|└────────────┴───────┘|
└──────────────────────┘
pl_hist_expr = df.select(pl.col("a").hist(bins, include_breakpoint=True))
nw_pl_hist_struct = nw.from_native(
nw_hist.to_polars().to_struct("a"), series_only=True
).to_frame()
nw_pa_hist_struct = (
nw.from_native(nw_hist.to_arrow().to_struct_array(), series_only=True)
.alias("a")
.to_frame()
)
>>> pl_hist_expr, nw_pl_hist_struct, nw_pa_hist_struct
(shape: (2, 1)
┌───────────┐
│ a │
│ --- │
│ struct[2] │
╞═══════════╡
│ {2.0,3} │
│ {3.0,2} │
└───────────┘,
┌──────────────────┐
|Narwhals DataFrame|
|------------------|
| shape: (2, 1) |
| ┌───────────┐ |
| │ a │ |
| │ --- │ |
| │ struct[2] │ |
| ╞═══════════╡ |
| │ {2.0,3} │ |
| │ {3.0,2} │ |
| └───────────┘ |
└──────────────────┘,
┌────────────────────────────────────────────┐
| Narwhals DataFrame |
|--------------------------------------------|
|pyarrow.Table |
|a: struct<breakpoint: double, count: uint32>|
| child 0, breakpoint: double |
| child 1, count: uint32 |
|---- |
|a: [ |
| -- is_valid: all not null |
| -- child 0 type: double |
|[2,3] |
| -- child 1 type: uint32 |
|[3,2]] |
└────────────────────────────────────────────┘)Note Proposed API 172729: |
Thanks @FBruzzesi! Okay easiest one out of the way first:
To merging this PR? Absolutely not! 😄 I have a few thoughts on the If you meant, are things like I think it would be a great candidate for exposing in |
|
(refactor(suggestion): Try out generalizing everything) collects everything into that class to make the diff easier 😅 I think these parts are the same across backends, so to narwhals/narwhals/_pandas_like/series.py Lines 1103 to 1106 in 1334849
narwhals/narwhals/_pandas_like/series.py Lines 1112 to 1115 in 1334849
Anywhere I left a note like this, could become that method: narwhals/narwhals/_pandas_like/series.py Lines 1125 to 1126 in 1334849
narwhals/narwhals/_pandas_like/series.py Lines 1143 to 1144 in 1334849
Things like this are dependent on the above narwhals/narwhals/_pandas_like/series.py Lines 1095 to 1097 in 1334849
I think a good example of splitting things across layers to implement is the work Its all about trying to spot the cases where we're repeating ourselves 😄
|
| from narwhals.series import Series | |
| msg = ( | |
| f"Unexpected type for `DataFrame.__getitem__`, got: {type(item)}.\n\n" | |
| "Hints:\n" | |
| "- use `df.item` to select a single item.\n" | |
| "- Use `df[indices, :]` to select rows positionally.\n" | |
| "- Use `df.filter(mask)` to filter rows based on a boolean mask." | |
| ) | |
| if isinstance(item, tuple): | |
| if len(item) > 2: | |
| tuple_msg = ( | |
| "Tuples cannot be passed to DataFrame.__getitem__ directly.\n\n" | |
| "Hint: instead of `df[indices]`, did you mean `df[indices, :]`?" | |
| ) | |
| raise TypeError(tuple_msg) | |
| rows = None if not item or is_slice_none(item[0]) else item[0] | |
| columns = None if len(item) < 2 or is_slice_none(item[1]) else item[1] | |
| if rows is None and columns is None: | |
| return self | |
| elif is_index_selector(item): | |
| rows = item | |
| columns = None | |
| elif is_sequence_like(item) or isinstance(item, (slice, str)): | |
| rows = None | |
| columns = item | |
| else: | |
| raise TypeError(msg) | |
| if isinstance(rows, str): | |
| raise TypeError(msg) | |
| compliant = self._compliant_frame | |
| if isinstance(columns, (int, str)): | |
| if isinstance(rows, int): | |
| return self.item(rows, columns) | |
| col_name = columns if isinstance(columns, str) else self.columns[columns] | |
| series = self.get_column(col_name) | |
| return series[rows] if rows is not None else series | |
| if isinstance(rows, Series): | |
| rows = rows._compliant_series | |
| if isinstance(columns, Series): | |
| columns = columns._compliant_series | |
| if rows is None: | |
| return self._with_compliant(compliant[:, columns]) | |
| if columns is None: | |
| return self._with_compliant(compliant[rows, :]) | |
| return self._with_compliant(compliant[rows, columns]) |
EagerDataFrame.__getitem__
narwhals/narwhals/_compliant/dataframe.py
Lines 439 to 489 in 80c5fc2
| def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... | |
| def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... | |
| def _select_multi_index( | |
| self, columns: SizedMultiIndexSelector[NativeSeriesT] | |
| ) -> Self: ... | |
| def _select_multi_name( | |
| self, columns: SizedMultiNameSelector[NativeSeriesT] | |
| ) -> Self: ... | |
| def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ... | |
| def _select_slice_name(self, columns: _SliceName) -> Self: ... | |
| def __getitem__( # noqa: C901, PLR0912 | |
| self, | |
| item: tuple[ | |
| SingleIndexSelector | MultiIndexSelector[EagerSeriesT], | |
| MultiColSelector[EagerSeriesT], | |
| ], | |
| ) -> Self: | |
| rows, columns = item | |
| compliant = self | |
| if not is_slice_none(columns): | |
| if isinstance(columns, Sized) and len(columns) == 0: | |
| return compliant.select() | |
| if is_index_selector(columns): | |
| if is_slice_index(columns) or is_range(columns): | |
| compliant = compliant._select_slice_index(columns) | |
| elif is_compliant_series(columns): | |
| compliant = self._select_multi_index(columns.native) | |
| else: | |
| compliant = compliant._select_multi_index(columns) | |
| elif isinstance(columns, slice): | |
| compliant = compliant._select_slice_name(columns) | |
| elif is_compliant_series(columns): | |
| compliant = self._select_multi_name(columns.native) | |
| elif is_sequence_like(columns): | |
| compliant = self._select_multi_name(columns) | |
| else: | |
| assert_never(columns) | |
| if not is_slice_none(rows): | |
| if isinstance(rows, int): | |
| compliant = compliant._gather([rows]) | |
| elif isinstance(rows, (slice, range)): | |
| compliant = compliant._gather_slice(rows) | |
| elif is_compliant_series(rows): | |
| compliant = compliant._gather(rows.native) | |
| elif is_sized_multi_index_selector(rows): | |
| compliant = compliant._gather(rows) | |
| else: | |
| assert_never(rows) | |
| return compliant |
dangotbanned
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks x7 @FBruzzesi!
I've got two very minor bits on the polars side.
As mentioned in (#2839 (comment)) - I'm happy for everything else to be follow-ups
I wanna move this back into `pyarrow` anyway #2839 (comment)
* refactor: add compliant level parent class * fix(typing): remove unused ignores * fix(typing): Use `pa.Int64Array` * fix(typing): `Generic` -> `Protocol` #2882 (comment) * fix(typing): Resolve most invariance issues #2882 (comment) * chore(typing): Ignore `linspace` for now I wanna move this back into `pyarrow` anyway #2839 (comment) * docs(typing): Explain why `cast` bins * move histdata into type-checking block * chore(typing): `CompliantSeries` -> `EagerSeries` * chore(typing): `Any` -> `EagerDataFrameAny` * docs(typing): note constructor issue this one gets me every time * rm '_' prefix, stay native in _linear_space --------- Co-authored-by: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
dangotbanned
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing everything in (#2839 (review)) @FBruzzesi


What type of PR is this? (check all applicable)
Related issues
*Series.hist#2487Checklist
If you have comments or can explain your changes, please do so below