Skip to content

Commit e31d06a

Browse files
committed
add default round method
1 parent 15915dc commit e31d06a

File tree

5 files changed

+18
-16
lines changed

5 files changed

+18
-16
lines changed

pandas/core/arrays/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2343,7 +2343,17 @@ def _add_logical_ops(cls) -> None:
23432343
setattr(cls, "__rxor__", cls._create_logical_method(roperator.rxor))
23442344

23452345
def round(self, decimals: int = 0, *args, **kwargs) -> Self:
2346-
raise AbstractMethodError(self)
2346+
# Implementer note: This is a non-optimized default implementation.
2347+
# Implementers are encouraged to override this method to avoid
2348+
# elementwise rounding.
2349+
if not self.dtype._is_numeric or self.dtype._is_boolean:
2350+
raise TypeError(
2351+
f"Cannot round {self.dtype} dtype as it is non-numeric or boolean"
2352+
)
2353+
return self._from_sequence(
2354+
[round(element) if not isna(element) else element for element in self.data],
2355+
dtype=self.dtype,
2356+
)
23472357

23482358

23492359
class ExtensionScalarOpsMixin(ExtensionOpsMixin):

pandas/tests/extension/base/methods.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -700,11 +700,11 @@ def test_equals_same_data_different_object(self, data):
700700
assert pd.Series(data).equals(pd.Series(data))
701701

702702
def test_round(self, data):
703-
if not data.dtype._is_numeric:
704-
pytest.skip("Round is only valid for numeric dtypes")
703+
if not data.dtype._is_numeric or data.dtype._is_boolean:
704+
pytest.skip("Round is only valid for numeric non-boolean dtypes")
705705
result = pd.Series(data).round()
706706
expected = pd.Series(
707-
[np.round(element) if pd.notna(element) else element for element in data],
707+
[round(element) if pd.notna(element) else element for element in data],
708708
dtype=data.dtype,
709709
)
710710
tm.assert_series_equal(result, expected)

pandas/tests/extension/test_arrow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,8 +1017,10 @@ def test_round(self, data, request):
10171017
# raises=pa.ArrowInvalid,
10181018
reason="ArrowArray.round converts dtype to double",
10191019
)
1020-
if pa.types.is_float32(data.dtype.pyarrow_dtype) or pa.types.is_float64(
1021-
data.dtype.pyarrow_dtype
1020+
if (
1021+
pa.types.is_float32(data.dtype.pyarrow_dtype)
1022+
or pa.types.is_float64(data.dtype.pyarrow_dtype)
1023+
or pa.types.is_decimal(data.dtype.pyarrow_dtype)
10221024
):
10231025
mark = None
10241026
if mark is not None:

pandas/tests/extension/test_interval.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,6 @@ def test_EA_types(self, engine, data):
8989
with pytest.raises(NotImplementedError, match=expected_msg):
9090
super().test_EA_types(engine, data)
9191

92-
@pytest.mark.xfail(reason="Round is not valid for IntervalArray.")
93-
def test_round(self, data):
94-
super().test_round(data)
95-
9692

9793
# TODO: either belongs in tests.arrays.interval or move into base tests.
9894
def test_fillna_non_scalar_raises(data_missing):

pandas/tests/extension/test_masked.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,6 @@ def test_invert(self, data, request):
359359
request.node.add_marker(mark)
360360
super().test_invert(data)
361361

362-
def test_round(self, data, request):
363-
if data.dtype == "boolean":
364-
mark = pytest.mark.xfail(reason="Cannot round boolean dtype")
365-
request.node.add_marker(mark)
366-
super().test_round(data)
367-
368362

369363
class Test2DCompat(base.Dim2CompatTests):
370364
pass

0 commit comments

Comments
 (0)