@@ -387,6 +387,31 @@ def test_merge_series(scalars_dfs, merge_how):
     assert_pandas_df_equal(bf_result, pd_result, ignore_order=True)
 
 
+def _convert_pandas_category(pd_s: pd.Series):
+    if not isinstance(pd_s.dtype, pd.CategoricalDtype):
+        raise ValueError("Input must be a pandas Series with categorical data.")
+
+    if len(pd_s.dtype.categories) == 0:
+        return pd.Series([pd.NA] * len(pd_s), name=pd_s.name)
+
+    pd_interval: pd.IntervalIndex = pd_s.cat.categories[pd_s.cat.codes]  # type: ignore
+    if pd_interval.closed == "left":
+        left_key = "left_inclusive"
+        right_key = "right_exclusive"
+    else:
+        left_key = "left_exclusive"
+        right_key = "right_inclusive"
+    return pd.Series(
+        [
+            {left_key: interval.left, right_key: interval.right}
+            if pd.notna(val)
+            else pd.NA
+            for val, interval in zip(pd_s, pd_interval)
+        ],
+        name=pd_s.name,
+    )
+
+
 @pytest.mark.parametrize(
     ("right"),
     [
@@ -420,23 +445,7 @@ def test_cut_default_labels(scalars_dfs, right):
     bf_result = bpd.cut(scalars_df["float64_col"], 5, right=right).to_pandas()
 
     # Convert to match data format
-    pd_interval = pd_result.cat.categories[pd_result.cat.codes]
-    if pd_interval.closed == "left":
-        left_key = "left_inclusive"
-        right_key = "right_exclusive"
-    else:
-        left_key = "left_exclusive"
-        right_key = "right_inclusive"
-    pd_result_converted = pd.Series(
-        [
-            {left_key: interval.left, right_key: interval.right}
-            if pd.notna(val)
-            else pd.NA
-            for val, interval in zip(pd_result, pd_interval)
-        ],
-        name=pd_result.name,
-    )
-
+    pd_result_converted = _convert_pandas_category(pd_result)
     pd.testing.assert_series_equal(
         bf_result, pd_result_converted, check_index=False, check_dtype=False
     )
@@ -458,47 +467,36 @@ def test_cut_numeric_breaks(scalars_dfs, breaks, right):
     bf_result = bpd.cut(scalars_df["float64_col"], breaks, right=right).to_pandas()
 
     # Convert to match data format
-    pd_interval = pd_result.cat.categories[pd_result.cat.codes]
-    if pd_interval.closed == "left":
-        left_key = "left_inclusive"
-        right_key = "right_exclusive"
-    else:
-        left_key = "left_exclusive"
-        right_key = "right_inclusive"
-
-    pd_result_converted = pd.Series(
-        [
-            {left_key: interval.left, right_key: interval.right}
-            if pd.notna(val)
-            else pd.NA
-            for val, interval in zip(pd_result, pd_interval)
-        ],
-        name=pd_result.name,
-    )
+    pd_result_converted = _convert_pandas_category(pd_result)
 
     pd.testing.assert_series_equal(
         bf_result, pd_result_converted, check_index=False, check_dtype=False
     )
 
 
 @pytest.mark.parametrize(
-    ("bins",),
+    "bins",
     [
-        (-1,),  # negative integer bins argument
-        ([],),  # empty iterable of bins
-        (["notabreak"],),  # iterable of wrong type
-        ([1],),  # numeric breaks with only one numeric
-        # this is supported by pandas but not by
-        # the bigquery operation and a bigframes workaround
-        # is not yet available. Should return column
-        # of structs with all NaN values.
+        pytest.param([], id="empty_list"),
+        pytest.param(
+            [1], id="single_int_list", marks=pytest.mark.skip(reason="b/404338651")
+        ),
+        pytest.param(pd.IntervalIndex.from_tuples([]), id="empty_interval_index"),
     ],
 )
-def test_cut_errors(scalars_dfs, bins):
-    scalars_df, _ = scalars_dfs
+def test_cut_w_edge_cases(scalars_dfs, bins):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    bf_result = bpd.cut(scalars_df["int64_too"], bins, labels=False).to_pandas()
+    if isinstance(bins, list):
+        bins = pd.IntervalIndex.from_tuples(bins)
+    pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False)
+
+    # Convert to match data format
+    pd_result_converted = _convert_pandas_category(pd_result)
 
-    with pytest.raises(ValueError):
-        bpd.cut(scalars_df["float64_col"], bins)
+    pd.testing.assert_series_equal(
+        bf_result, pd_result_converted, check_index=False, check_dtype=False
+    )
 
 
 @pytest.mark.parametrize(
@@ -529,23 +527,7 @@ def test_cut_with_interval(scalars_dfs, bins, right):
     pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False, right=right)
 
     # Convert to match data format
-    pd_interval = pd_result.cat.categories[pd_result.cat.codes]
-    if pd_interval.closed == "left":
-        left_key = "left_inclusive"
-        right_key = "right_exclusive"
-    else:
-        left_key = "left_exclusive"
-        right_key = "right_inclusive"
-
-    pd_result_converted = pd.Series(
-        [
-            {left_key: interval.left, right_key: interval.right}
-            if pd.notna(val)
-            else pd.NA
-            for val, interval in zip(pd_result, pd_interval)
-        ],
-        name=pd_result.name,
-    )
+    pd_result_converted = _convert_pandas_category(pd_result)
 
     pd.testing.assert_series_equal(
         bf_result, pd_result_converted, check_index=False, check_dtype=False
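For context, a minimal standalone sketch of the conversion that the new _convert_pandas_category helper performs: pd.cut returns a categorical Series whose categories are pd.Interval objects, and the helper flattens each interval into a dict keyed by which side is inclusive, so the result can be compared against the struct column produced by BigQuery DataFrames. The input values and bin count below are made up for illustration, and the helper's logic is mirrored inline so the snippet runs on its own; it is not part of the diff above.

import pandas as pd

# Illustrative input (not from the test fixtures): pd.cut yields a
# categorical Series of pd.Interval values, with NaN for missing rows.
s = pd.cut(pd.Series([1.0, 5.0, None, 9.0], name="float64_col"), bins=3)

# Mirror of _convert_pandas_category: map each categorical code back to
# its interval, then emit one dict (or pd.NA for missing values) per row.
intervals = s.cat.categories[s.cat.codes]
if intervals.closed == "left":
    left_key, right_key = "left_inclusive", "right_exclusive"
else:
    left_key, right_key = "left_exclusive", "right_inclusive"

converted = pd.Series(
    [
        {left_key: iv.left, right_key: iv.right} if pd.notna(v) else pd.NA
        for v, iv in zip(s, intervals)
    ],
    name=s.name,
)
print(converted)  # each non-null row becomes a dict of the interval's edges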