Skip to content

Commit

Permalink
BUG: Series[Period][mask] = 'foo' raising inconsistent with non-mask …
Browse files Browse the repository at this point in the history
…indexing (#45768)
  • Loading branch information
jbrockmendel authored Feb 2, 2022
1 parent 06dac44 commit be8d1ec
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 29 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ Indexing
- Bug in :meth:`DataFrame.iloc` where indexing a single row on a :class:`DataFrame` with a single ExtensionDtype column gave a copy instead of a view on the underlying data (:issue:`45241`)
- Bug in setting a NA value (``None`` or ``np.nan``) into a :class:`Series` with int-based :class:`IntervalDtype` incorrectly casting to object dtype instead of a float-based :class:`IntervalDtype` (:issue:`45568`)
- Bug in :meth:`Series.__setitem__` with a non-integer :class:`Index` when using an integer key to set a value that cannot be set inplace where a ``ValueError`` was raised insead of casting to a common dtype (:issue:`45070`)
- Bug in :meth:`Series.__setitem__` when setting incompatible values into a ``PeriodDtype`` or ``IntervalDtype`` :class:`Series` raising when indexing with a boolean mask but coercing when indexing with otherwise-equivalent indexers; these now consistently coerce, along with :meth:`Series.mask` and :meth:`Series.where` (:issue:`45768`)
- Bug in :meth:`Series.loc.__setitem__` and :meth:`Series.loc.__getitem__` not raising when using multiple keys without using a :class:`MultiIndex` (:issue:`13831`)
- Bug when setting a value too large for a :class:`Series` dtype failing to coerce to a common type (:issue:`26049`, :issue:`32878`)
- Bug in :meth:`loc.__setitem__` treating ``range`` keys as positional instead of label-based (:issue:`45479`)
Expand Down
35 changes: 13 additions & 22 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,8 @@ def where(self, other, cond) -> list[Block]:

cond = extract_bool_array(cond)

orig_other = other
orig_cond = cond
other = self._maybe_squeeze_arg(other)
cond = self._maybe_squeeze_arg(cond)

Expand All @@ -1395,21 +1397,15 @@ def where(self, other, cond) -> list[Block]:

if is_interval_dtype(self.dtype):
# TestSetitemFloatIntervalWithIntIntervalValues
blk = self.coerce_to_target_dtype(other)
if blk.dtype == _dtype_obj:
# For now at least only support casting e.g.
# Interval[int64]->Interval[float64]
raise
return blk.where(other, cond)
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, "infer")

elif isinstance(self, NDArrayBackedExtensionBlock):
# NB: not (yet) the same as
# isinstance(values, NDArrayBackedExtensionArray)
if isinstance(self.dtype, PeriodDtype):
# TODO: don't special-case
raise
blk = self.coerce_to_target_dtype(other)
nbs = blk.where(other, cond)
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, "infer")

else:
Expand All @@ -1426,6 +1422,8 @@ def putmask(self, mask, new) -> list[Block]:

values = self.values

orig_new = new
orig_mask = mask
new = self._maybe_squeeze_arg(new)
mask = self._maybe_squeeze_arg(mask)

Expand All @@ -1438,21 +1436,14 @@ def putmask(self, mask, new) -> list[Block]:
if is_interval_dtype(self.dtype):
# Discussion about what we want to support in the general
# case GH#39584
blk = self.coerce_to_target_dtype(new)
if blk.dtype == _dtype_obj:
# For now at least, only support casting e.g.
# Interval[int64]->Interval[float64],
raise
return blk.putmask(mask, new)
blk = self.coerce_to_target_dtype(orig_new)
return blk.putmask(orig_mask, orig_new)

elif isinstance(self, NDArrayBackedExtensionBlock):
# NB: not (yet) the same as
# isinstance(values, NDArrayBackedExtensionArray)
if isinstance(self.dtype, PeriodDtype):
# TODO: don't special-case
raise
blk = self.coerce_to_target_dtype(new)
return blk.putmask(mask, new)
blk = self.coerce_to_target_dtype(orig_new)
return blk.putmask(orig_mask, orig_new)

else:
raise
Expand Down
8 changes: 7 additions & 1 deletion pandas/tests/arrays/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,16 @@ def test_set_closed(self, closed, new_closed):
],
)
def test_where_raises(self, other):
# GH#45768 The IntervalArray methods raises; the Series method coerces
ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4], closed="left"))
mask = np.array([True, False, True])
match = "'value.closed' is 'right', expected 'left'."
with pytest.raises(ValueError, match=match):
ser.where([True, False, True], other=other)
ser.array._where(mask, other)

res = ser.where(mask, other=other)
expected = ser.astype(object).where(mask, other)
tm.assert_series_equal(res, expected)

def test_shift(self):
# https://github.com/pandas-dev/pandas/issues/31495, GH#22428, GH#31502
Expand Down
8 changes: 7 additions & 1 deletion pandas/tests/arrays/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,16 @@ def test_sub_period():
[pd.Period("2000", freq="H"), period_array(["2000", "2001", "2000"], freq="H")],
)
def test_where_different_freq_raises(other):
# GH#45768 The PeriodArray method raises, the Series method coerces
ser = pd.Series(period_array(["2000", "2001", "2002"], freq="D"))
cond = np.array([True, False, True])

with pytest.raises(IncompatibleFrequency, match="freq"):
ser.where(cond, other)
ser.array._where(cond, other)

res = ser.where(cond, other)
expected = ser.astype(object).where(cond, other)
tm.assert_series_equal(res, expected)


# ----------------------------------------------------------------------------
Expand Down
49 changes: 44 additions & 5 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,20 @@ def test_where_interval_noop(self):
res = ser.where(ser.notna())
tm.assert_series_equal(res, ser)

def test_where_interval_fullop_downcast(self, frame_or_series):
# GH#45768
obj = frame_or_series([pd.Interval(0, 0)] * 2)
other = frame_or_series([1.0, 2.0])
res = obj.where(~obj.notna(), other)

# since all entries are being changed, we will downcast result
# from object to ints (not floats)
tm.assert_equal(res, other.astype(np.int64))

# unlike where, Block.putmask does not downcast
obj.mask(obj.notna(), other, inplace=True)
tm.assert_equal(obj, other.astype(object))

@pytest.mark.parametrize(
"dtype",
[
Expand Down Expand Up @@ -736,6 +750,16 @@ def test_where_datetimelike_noop(self, dtype):
res4 = df.mask(mask2, "foo")
tm.assert_frame_equal(res4, df)

# opposite case where we are replacing *all* values -> we downcast
# from object dtype # GH#45768
res5 = df.where(mask2, 4)
expected = DataFrame(4, index=df.index, columns=df.columns)
tm.assert_frame_equal(res5, expected)

# unlike where, Block.putmask does not downcast
df.mask(~mask2, 4, inplace=True)
tm.assert_frame_equal(df, expected.astype(object))


def test_where_try_cast_deprecated(frame_or_series):
obj = DataFrame(np.random.randn(4, 3))
Expand Down Expand Up @@ -894,14 +918,29 @@ def test_where_period_invalid_na(frame_or_series, as_cat, request):
else:
msg = "value should be a 'Period'"

with pytest.raises(TypeError, match=msg):
obj.where(mask, tdnat)
if as_cat:
with pytest.raises(TypeError, match=msg):
obj.where(mask, tdnat)

with pytest.raises(TypeError, match=msg):
obj.mask(mask, tdnat)
with pytest.raises(TypeError, match=msg):
obj.mask(mask, tdnat)

with pytest.raises(TypeError, match=msg):
obj.mask(mask, tdnat, inplace=True)

else:
# With PeriodDtype, ser[i] = tdnat coerces instead of raising,
# so for consistency, ser[mask] = tdnat must as well
expected = obj.astype(object).where(mask, tdnat)
result = obj.where(mask, tdnat)
tm.assert_equal(result, expected)

expected = obj.astype(object).mask(mask, tdnat)
result = obj.mask(mask, tdnat)
tm.assert_equal(result, expected)

with pytest.raises(TypeError, match=msg):
obj.mask(mask, tdnat, inplace=True)
tm.assert_equal(obj, expected)


def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/series/indexing/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
IntervalIndex,
MultiIndex,
NaT,
Period,
Series,
Timedelta,
Timestamp,
Expand Down Expand Up @@ -1317,6 +1318,22 @@ def obj(self):
return Series(timedelta_range("1 day", periods=4))


@pytest.mark.parametrize(
"val", ["foo", Period("2016", freq="Y"), Interval(1, 2, closed="both")]
)
@pytest.mark.parametrize("exp_dtype", [object])
class TestPeriodIntervalCoercion(CoercionTest):
# GH#45768
@pytest.fixture(
params=[
period_range("2016-01-01", periods=3, freq="D"),
interval_range(1, 5),
]
)
def obj(self, request):
return Series(request.param)


def test_20643():
# closed by GH#45121
orig = Series([0, 1, 2], index=["a", "b", "c"])
Expand Down

0 comments on commit be8d1ec

Please sign in to comment.