Skip to content

Commit 434fb5d

Browse files
authored
fix: fix pandas.cut errors with empty bins (#1499)
* fix: fix pandas.cut errors with empty bins * nit * refactor if branches to be more readable
1 parent 802183d commit 434fb5d

File tree

4 files changed

+89
-100
lines changed

4 files changed

+89
-100
lines changed

bigframes/core/reshape/tile.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,32 +41,33 @@ def cut(
4141
right: typing.Optional[bool] = True,
4242
labels: typing.Union[typing.Iterable[str], bool, None] = None,
4343
) -> bigframes.series.Series:
44-
if isinstance(bins, int) and bins <= 0:
45-
raise ValueError("`bins` should be a positive integer.")
46-
47-
# TODO: Check `right` does not apply for IntervalIndex.
44+
if labels is not None and labels is not False:
45+
raise NotImplementedError(
46+
"The 'labels' parameter must be either False or None. "
47+
"Please provide a valid value for 'labels'."
48+
)
4849

49-
if isinstance(bins, typing.Iterable):
50+
if isinstance(bins, int):
51+
if bins <= 0:
52+
raise ValueError("`bins` should be a positive integer.")
53+
op = agg_ops.CutOp(bins, right=right, labels=labels)
54+
return x._apply_window_op(op, window_spec=window_specs.unbound())
55+
elif isinstance(bins, typing.Iterable):
5056
if isinstance(bins, pd.IntervalIndex):
51-
# TODO: test an empty interval index
5257
as_index: pd.IntervalIndex = bins
5358
bins = tuple((bin.left.item(), bin.right.item()) for bin in bins)
5459
# To maintain consistency with pandas' behavior
5560
right = True
5661
elif len(list(bins)) == 0:
57-
raise ValueError("`bins` iterable should have at least one item")
62+
as_index = pd.IntervalIndex.from_tuples(list(bins))
63+
bins = tuple()
5864
elif isinstance(list(bins)[0], tuple):
5965
as_index = pd.IntervalIndex.from_tuples(list(bins))
6066
bins = tuple(bins)
6167
# To maintain consistency with pandas' behavior
6268
right = True
6369
elif pd.api.types.is_number(list(bins)[0]):
6470
bins_list = list(bins)
65-
if len(bins_list) < 2:
66-
raise ValueError(
67-
"`bins` iterable of numeric breaks should have"
68-
" at least two items"
69-
)
7071
as_index = pd.IntervalIndex.from_breaks(bins_list)
7172
single_type = all([isinstance(n, type(bins_list[0])) for n in bins_list])
7273
numeric_type = type(bins_list[0]) if single_type else float
@@ -77,21 +78,20 @@ def cut(
7778
]
7879
)
7980
else:
80-
raise ValueError("`bins` iterable should contain tuples or numerics")
81+
raise ValueError("`bins` iterable should contain tuples or numerics.")
8182

8283
if as_index.is_overlapping:
8384
raise ValueError("Overlapping IntervalIndex is not accepted.")
84-
85-
if labels is not None and labels is not False:
86-
raise NotImplementedError(
87-
"The 'labels' parameter must be either False or None. "
88-
"Please provide a valid value for 'labels'."
89-
)
90-
91-
return x._apply_window_op(
92-
agg_ops.CutOp(bins, right=right, labels=labels),
93-
window_spec=window_specs.unbound(),
94-
)
85+
elif len(as_index) == 0:
86+
op = agg_ops.CutOp(bins, right=right, labels=labels)
87+
return bigframes.series.Series(
88+
[pd.NA] * len(x), dtype=op.output_type(), name=x.name
89+
)
90+
else:
91+
op = agg_ops.CutOp(bins, right=right, labels=labels)
92+
return x._apply_window_op(op, window_spec=window_specs.unbound())
93+
else:
94+
raise ValueError("`bins` must be an integer or interable.")
9595

9696

9797
cut.__doc__ = vendored_pandas_tile.cut.__doc__

bigframes/operations/aggregations.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,12 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
351351
return dtypes.INT_DTYPE
352352
else:
353353
# Assumption: buckets use same numeric type
354-
interval_dtype = (
355-
pa.float64()
356-
if isinstance(self.bins, int)
357-
else dtypes.infer_literal_arrow_type(list(self.bins)[0][0])
358-
)
354+
if isinstance(self.bins, int):
355+
interval_dtype = pa.float64()
356+
elif len(list(self.bins)) == 0:
357+
interval_dtype = pa.int64()
358+
else:
359+
interval_dtype = dtypes.infer_literal_arrow_type(list(self.bins)[0][0])
359360
pa_type = pa.struct(
360361
[
361362
pa.field(

tests/system/small/test_pandas.py

Lines changed: 46 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,31 @@ def test_merge_series(scalars_dfs, merge_how):
387387
assert_pandas_df_equal(bf_result, pd_result, ignore_order=True)
388388

389389

390+
def _convert_pandas_category(pd_s: pd.Series):
391+
if not isinstance(pd_s.dtype, pd.CategoricalDtype):
392+
raise ValueError("Input must be a pandas Series with categorical data.")
393+
394+
if len(pd_s.dtype.categories) == 0:
395+
return pd.Series([pd.NA] * len(pd_s), name=pd_s.name)
396+
397+
pd_interval: pd.IntervalIndex = pd_s.cat.categories[pd_s.cat.codes] # type: ignore
398+
if pd_interval.closed == "left":
399+
left_key = "left_inclusive"
400+
right_key = "right_exclusive"
401+
else:
402+
left_key = "left_exclusive"
403+
right_key = "right_inclusive"
404+
return pd.Series(
405+
[
406+
{left_key: interval.left, right_key: interval.right}
407+
if pd.notna(val)
408+
else pd.NA
409+
for val, interval in zip(pd_s, pd_interval)
410+
],
411+
name=pd_s.name,
412+
)
413+
414+
390415
@pytest.mark.parametrize(
391416
("right"),
392417
[
@@ -420,23 +445,7 @@ def test_cut_default_labels(scalars_dfs, right):
420445
bf_result = bpd.cut(scalars_df["float64_col"], 5, right=right).to_pandas()
421446

422447
# Convert to match data format
423-
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
424-
if pd_interval.closed == "left":
425-
left_key = "left_inclusive"
426-
right_key = "right_exclusive"
427-
else:
428-
left_key = "left_exclusive"
429-
right_key = "right_inclusive"
430-
pd_result_converted = pd.Series(
431-
[
432-
{left_key: interval.left, right_key: interval.right}
433-
if pd.notna(val)
434-
else pd.NA
435-
for val, interval in zip(pd_result, pd_interval)
436-
],
437-
name=pd_result.name,
438-
)
439-
448+
pd_result_converted = _convert_pandas_category(pd_result)
440449
pd.testing.assert_series_equal(
441450
bf_result, pd_result_converted, check_index=False, check_dtype=False
442451
)
@@ -458,47 +467,36 @@ def test_cut_numeric_breaks(scalars_dfs, breaks, right):
458467
bf_result = bpd.cut(scalars_df["float64_col"], breaks, right=right).to_pandas()
459468

460469
# Convert to match data format
461-
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
462-
if pd_interval.closed == "left":
463-
left_key = "left_inclusive"
464-
right_key = "right_exclusive"
465-
else:
466-
left_key = "left_exclusive"
467-
right_key = "right_inclusive"
468-
469-
pd_result_converted = pd.Series(
470-
[
471-
{left_key: interval.left, right_key: interval.right}
472-
if pd.notna(val)
473-
else pd.NA
474-
for val, interval in zip(pd_result, pd_interval)
475-
],
476-
name=pd_result.name,
477-
)
470+
pd_result_converted = _convert_pandas_category(pd_result)
478471

479472
pd.testing.assert_series_equal(
480473
bf_result, pd_result_converted, check_index=False, check_dtype=False
481474
)
482475

483476

484477
@pytest.mark.parametrize(
485-
("bins",),
478+
"bins",
486479
[
487-
(-1,), # negative integer bins argument
488-
([],), # empty iterable of bins
489-
(["notabreak"],), # iterable of wrong type
490-
([1],), # numeric breaks with only one numeric
491-
# this is supported by pandas but not by
492-
# the bigquery operation and a bigframes workaround
493-
# is not yet available. Should return column
494-
# of structs with all NaN values.
480+
pytest.param([], id="empty_list"),
481+
pytest.param(
482+
[1], id="single_int_list", marks=pytest.mark.skip(reason="b/404338651")
483+
),
484+
pytest.param(pd.IntervalIndex.from_tuples([]), id="empty_interval_index"),
495485
],
496486
)
497-
def test_cut_errors(scalars_dfs, bins):
498-
scalars_df, _ = scalars_dfs
487+
def test_cut_w_edge_cases(scalars_dfs, bins):
488+
scalars_df, scalars_pandas_df = scalars_dfs
489+
bf_result = bpd.cut(scalars_df["int64_too"], bins, labels=False).to_pandas()
490+
if isinstance(bins, list):
491+
bins = pd.IntervalIndex.from_tuples(bins)
492+
pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False)
493+
494+
# Convert to match data format
495+
pd_result_converted = _convert_pandas_category(pd_result)
499496

500-
with pytest.raises(ValueError):
501-
bpd.cut(scalars_df["float64_col"], bins)
497+
pd.testing.assert_series_equal(
498+
bf_result, pd_result_converted, check_index=False, check_dtype=False
499+
)
502500

503501

504502
@pytest.mark.parametrize(
@@ -529,23 +527,7 @@ def test_cut_with_interval(scalars_dfs, bins, right):
529527
pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False, right=right)
530528

531529
# Convert to match data format
532-
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
533-
if pd_interval.closed == "left":
534-
left_key = "left_inclusive"
535-
right_key = "right_exclusive"
536-
else:
537-
left_key = "left_exclusive"
538-
right_key = "right_inclusive"
539-
540-
pd_result_converted = pd.Series(
541-
[
542-
{left_key: interval.left, right_key: interval.right}
543-
if pd.notna(val)
544-
else pd.NA
545-
for val, interval in zip(pd_result, pd_interval)
546-
],
547-
name=pd_result.name,
548-
)
530+
pd_result_converted = _convert_pandas_category(pd_result)
549531

550532
pd.testing.assert_series_equal(
551533
bf_result, pd_result_converted, check_index=False, check_dtype=False

tests/unit/test_pandas.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,20 @@ def test_cut_raises_with_labels():
101101

102102

103103
@pytest.mark.parametrize(
104-
("bins",),
105-
(
106-
(0,),
107-
(-1,),
108-
),
104+
("bins", "error_message"),
105+
[
106+
pytest.param(1.5, "`bins` must be an integer or interable.", id="float"),
107+
pytest.param(0, "`bins` should be a positive integer.", id="zero_int"),
108+
pytest.param(-1, "`bins` should be a positive integer.", id="neg_int"),
109+
pytest.param(
110+
["notabreak"],
111+
"`bins` iterable should contain tuples or numerics",
112+
id="iterable_w_wrong_type",
113+
),
114+
],
109115
)
110-
def test_cut_raises_with_invalid_bins(bins: int):
111-
with pytest.raises(ValueError, match="`bins` should be a positive integer."):
116+
def test_cut_raises_with_invalid_bins(bins: int, error_message: str):
117+
with pytest.raises(ValueError, match=error_message):
112118
mock_series = mock.create_autospec(bigframes.pandas.Series, instance=True)
113119
bigframes.pandas.cut(mock_series, bins, labels=False)
114120

0 commit comments

Comments
 (0)