Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix binary operations on attrs for Series and DataFrame #59636

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,8 @@ def _duplicated(self, keep: DropKeep = "first") -> npt.NDArray[np.bool_]:
return algorithms.duplicated(arr, keep=keep)

def _arith_method(self, other, op):
if not getattr(self, "attrs", None) and getattr(other, "attrs", None):
self.attrs = other.attrs
res_name = ops.get_op_result_name(self, other)

lvalues = self._values
Expand Down
12 changes: 12 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7814,13 +7814,19 @@ class diet
def _cmp_method(self, other, op):
axis: Literal[1] = 1 # only relevant for Series other case

if not getattr(self, "attrs", None) and getattr(other, "attrs", None):
self.attrs = other.attrs

self, other = self._align_for_op(other, axis, flex=False, level=None)

# See GH#4537 for discussion of scalar op behavior
new_data = self._dispatch_frame_op(other, op, axis=axis)
return self._construct_result(new_data)

def _arith_method(self, other, op):
if not getattr(self, "attrs", None) and getattr(other, "attrs", None):
self.attrs = other.attrs

if self._should_reindex_frame_op(other, op, 1, None, None):
return self._arith_method_with_reindex(other, op)

Expand Down Expand Up @@ -8186,6 +8192,9 @@ def _flex_arith_method(

new_data = self._dispatch_frame_op(other, op)

if not getattr(self, "attrs", None) and getattr(other, "attrs", None):
self.attrs = other.attrs

return self._construct_result(new_data)

def _construct_result(self, result) -> DataFrame:
Expand Down Expand Up @@ -8224,6 +8233,9 @@ def _flex_cmp_method(self, other, op, *, axis: Axis = "columns", level=None):

self, other = self._align_for_op(other, axis, flex=True, level=level)

if not getattr(self, "attrs", None) and getattr(other, "attrs", None):
self.attrs = other.attrs

new_data = self._dispatch_frame_op(other, op, axis=axis)
return self._construct_result(new_data)

Expand Down
9 changes: 9 additions & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5843,6 +5843,9 @@ def to_period(
def _cmp_method(self, other, op):
res_name = ops.get_op_result_name(self, other)

if not getattr(self, "attrs", None) and getattr(other, "attrs", None):
self.attrs = other.attrs

if isinstance(other, Series) and not self._indexed_same(other):
raise ValueError("Can only compare identically-labeled Series objects")

Expand All @@ -5854,6 +5857,8 @@ def _cmp_method(self, other, op):
return self._construct_result(res_values, name=res_name)

def _logical_method(self, other, op):
if not getattr(self, "attrs", None) and getattr(other, "attrs", None):
self.attrs = other.attrs
res_name = ops.get_op_result_name(self, other)
self, other = self._align_for_op(other, align_asobject=True)

Expand Down Expand Up @@ -5923,6 +5928,10 @@ def _binop(self, other: Series, func, level=None, fill_value=None) -> Series:
result = func(this_vals, other_vals)

name = ops.get_op_result_name(self, other)

if not getattr(this, "attrs", None) and getattr(other, "attrs", None):
this.attrs = other.attrs

out = this._construct_result(result, name)
return cast(Series, out)

Expand Down
16 changes: 16 additions & 0 deletions pandas/tests/frame/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ def test_attrs_deepcopy(self):
assert result.attrs == df.attrs
assert result.attrs["tags"] is not df.attrs["tags"]

def test_attrs_binary_operations(self):
# GH 51607
df_1 = DataFrame({"A": [2, 3]})
df_2 = DataFrame({"A": [-3, 9]})
attrs = {"info": "DataFrame"}
df_1.attrs = attrs
assert (df_1 + df_2).attrs == attrs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than doing this you can just use the all_binary_operators fixture from conftest.py (I think)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made the change.

assert (df_2 + df_1).attrs == attrs
assert (df_2 - df_1).attrs == attrs
assert (df_2 / df_1).attrs == attrs
assert (df_2 * df_1).attrs == attrs
assert (df_2.add(df_1)).attrs == attrs
assert (df_2.sub(df_1)).attrs == attrs
assert (df_2.div(df_1)).attrs == attrs
assert (df_2.mul(df_1)).attrs == attrs

@pytest.mark.parametrize("allows_duplicate_labels", [True, False, None])
def test_set_flags(
self,
Expand Down
47 changes: 0 additions & 47 deletions pandas/tests/generic/test_finalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,53 +427,6 @@ def test_binops(request, args, annotate, all_binary_operators):
if annotate == "right" and isinstance(right, int):
pytest.skip("right is an int and doesn't support .attrs")

if not (isinstance(left, int) or isinstance(right, int)) and annotate != "both":
if not all_binary_operators.__name__.startswith("r"):
if annotate == "right" and isinstance(left, type(right)):
request.applymarker(
pytest.mark.xfail(
reason=f"{all_binary_operators} doesn't work when right has "
f"attrs and both are {type(left)}"
)
)
if not isinstance(left, type(right)):
if annotate == "left" and isinstance(left, pd.Series):
request.applymarker(
pytest.mark.xfail(
reason=f"{all_binary_operators} doesn't work when the "
"objects are different Series has attrs"
)
)
elif annotate == "right" and isinstance(right, pd.Series):
request.applymarker(
pytest.mark.xfail(
reason=f"{all_binary_operators} doesn't work when the "
"objects are different Series has attrs"
)
)
else:
if annotate == "left" and isinstance(left, type(right)):
request.applymarker(
pytest.mark.xfail(
reason=f"{all_binary_operators} doesn't work when left has "
f"attrs and both are {type(left)}"
)
)
if not isinstance(left, type(right)):
if annotate == "right" and isinstance(right, pd.Series):
request.applymarker(
pytest.mark.xfail(
reason=f"{all_binary_operators} doesn't work when the "
"objects are different Series has attrs"
)
)
elif annotate == "left" and isinstance(left, pd.Series):
request.applymarker(
pytest.mark.xfail(
reason=f"{all_binary_operators} doesn't work when the "
"objects are different Series has attrs"
)
)
if annotate in {"left", "both"} and not isinstance(left, int):
left.attrs = {"a": 1}
if annotate in {"right", "both"} and not isinstance(right, int):
Expand Down
16 changes: 16 additions & 0 deletions pandas/tests/series/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,22 @@ def test_attrs(self):
result = s + 1
assert result.attrs == {"version": 1}

def test_attrs_binary_operations(self):
# GH 51607
s1 = Series([2, 5])
s2 = Series([7, -1])
attrs = {"info": "Series"}
s1.attrs = attrs
assert (s1 + s2).attrs == attrs
assert (s2 + s1).attrs == attrs
assert (s2 - s1).attrs == attrs
assert (s2 / s1).attrs == attrs
assert (s2 * s1).attrs == attrs
assert (s2.add(s1)).attrs == attrs
assert (s2.sub(s1)).attrs == attrs
assert (s2.div(s1)).attrs == attrs
assert (s2.mul(s1)).attrs == attrs

@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
Expand Down
Loading