Skip to content

Commit 1f3b514

Browse files
authored
FIX-#2362: fix handling slices in 'DataFrame.__setitem__' (#2741)
Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
1 parent 61c6e99 commit 1f3b514

File tree

5 files changed

+54
-27
lines changed

5 files changed

+54
-27
lines changed

modin/pandas/base.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2708,7 +2708,21 @@ def __getitem__(self, key):
27082708
else:
27092709
return self._getitem(key)
27102710

2711-
def _getitem_slice(self, key):
2711+
def _setitem_slice(self, key: slice, value):
2712+
"""
2713+
Set rows specified by 'key' slice with 'value'.
2714+
2715+
Parameters
2716+
----------
2717+
key: location or index based slice,
2718+
Key that points to the rows to modify.
2719+
value: any,
2720+
Value to assign to the rows.
2721+
"""
2722+
indexer = convert_to_index_sliceable(pandas.DataFrame(index=self.index), key)
2723+
self.iloc[indexer] = value
2724+
2725+
def _getitem_slice(self, key: slice):
27122726
if key.start is None and key.stop is None:
27132727
return self.copy()
27142728
return self.iloc[key]

modin/pandas/dataframe.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,6 +2008,9 @@ def __setattr__(self, key, value):
20082008
object.__setattr__(self, key, value)
20092009

20102010
def __setitem__(self, key, value):
2011+
if isinstance(key, slice):
2012+
return self._setitem_slice(key, value)
2013+
20112014
if hashable(key) and key not in self.columns:
20122015
if isinstance(value, Series) and len(self.columns) == 0:
20132016
self._query_compiler = value._query_compiler.copy()
@@ -2038,24 +2041,23 @@ def __setitem__(self, key, value):
20382041
self.insert(loc=len(self.columns), column=key, value=value)
20392042
return
20402043

2041-
if not isinstance(key, str):
2042-
2044+
if not hashable(key):
20432045
if isinstance(key, DataFrame) or isinstance(key, np.ndarray):
20442046
if isinstance(key, np.ndarray):
20452047
if key.shape != self.shape:
20462048
raise ValueError("Array must be same shape as DataFrame")
20472049
key = DataFrame(key, columns=self.columns)
20482050
return self.mask(key, value, inplace=True)
20492051

2050-
def setitem_without_string_columns(df):
2052+
def setitem_unhashable_key(df):
20512053
# Arrow makes memory-mapped objects immutable, so copy will allow them
20522054
# to be mutable again.
20532055
df = df.copy(True)
20542056
df[key] = value
20552057
return df
20562058

20572059
return self._update_inplace(
2058-
self._default_to_pandas(setitem_without_string_columns)._query_compiler
2060+
self._default_to_pandas(setitem_unhashable_key)._query_compiler
20592061
)
20602062
if is_list_like(value):
20612063
if isinstance(value, (pandas.DataFrame, DataFrame)):

modin/pandas/series.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,8 @@ def __round__(self, decimals=0):
285285
)
286286

287287
def __setitem__(self, key, value):
288-
if isinstance(key, slice) and (
289-
isinstance(key.start, int) or isinstance(key.stop, int)
290-
):
291-
# There could be two types of slices:
292-
# - Location based slice (1:5)
293-
# - Labels based slice ("a":"e")
294-
# For location based slice we're going to `iloc`, since `loc` can't manage it.
295-
self.iloc[key] = value
288+
if isinstance(key, slice):
289+
self._setitem_slice(key, value)
296290
else:
297291
self.loc[key] = value
298292

modin/pandas/test/dataframe/test_indexing.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,19 +1215,16 @@ def test___setitem__(data):
12151215
df_equals(modin_df, pandas_df)
12161216
assert isinstance(modin_df["new_col"][0], type(pandas_df["new_col"][0]))
12171217

1218+
modin_df[1:5] = 10
1219+
pandas_df[1:5] = 10
1220+
df_equals(modin_df, pandas_df)
1221+
12181222
# Transpose test
12191223
modin_df = pd.DataFrame(data).T
12201224
pandas_df = pandas.DataFrame(data).T
12211225

1222-
# We default to pandas on non-string column names
1223-
if not all(isinstance(c, str) for c in modin_df.columns):
1224-
with pytest.warns(UserWarning):
1225-
modin_df[modin_df.columns[0]] = 0
1226-
else:
1227-
modin_df[modin_df.columns[0]] = 0
1228-
1226+
modin_df[modin_df.columns[0]] = 0
12291227
pandas_df[pandas_df.columns[0]] = 0
1230-
12311228
df_equals(modin_df, pandas_df)
12321229

12331230
modin_df.columns = [str(i) for i in modin_df.columns]
@@ -1240,7 +1237,10 @@ def test___setitem__(data):
12401237

12411238
modin_df[modin_df.columns[0]][modin_df.index[0]] = 12345
12421239
pandas_df[pandas_df.columns[0]][pandas_df.index[0]] = 12345
1240+
df_equals(modin_df, pandas_df)
12431241

1242+
modin_df[1:5] = 10
1243+
pandas_df[1:5] = 10
12441244
df_equals(modin_df, pandas_df)
12451245

12461246

modin/pandas/test/test_series.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -529,14 +529,31 @@ def test___setitem__(data):
529529
@pytest.mark.parametrize(
530530
"key",
531531
[
532-
pytest.param(slice(1, 3), id="numeric_slice"),
533-
pytest.param(slice("a", "c"), id="index_based_slice"),
534-
pytest.param(["a", "c", "e"], id="list_of_labels"),
535-
pytest.param([True, False, True, False, True], id="boolean_mask"),
532+
pytest.param(lambda idx: slice(1, 3), id="location_based_slice"),
533+
pytest.param(lambda idx: slice(idx[1], idx[-1]), id="index_based_slice"),
534+
pytest.param(lambda idx: [idx[0], idx[2], idx[-1]], id="list_of_labels"),
535+
pytest.param(
536+
lambda idx: [True if i % 2 else False for i in range(len(idx))],
537+
id="boolean_mask",
538+
),
536539
],
537540
)
538-
def test___setitem___non_hashable(key):
539-
md_sr, pd_sr = create_test_series([1, 2, 3, 4, 5], index=["a", "b", "c", "d", "e"])
541+
@pytest.mark.parametrize(
542+
"index",
543+
[
544+
pytest.param(
545+
lambda idx_len: [chr(x) for x in range(ord("a"), ord("a") + idx_len)],
546+
id="str_index",
547+
),
548+
pytest.param(lambda idx_len: list(range(1, idx_len + 1)), id="int_index"),
549+
],
550+
)
551+
def test___setitem___non_hashable(key, index):
552+
data = np.arange(5)
553+
index = index(len(data))
554+
key = key(index)
555+
md_sr, pd_sr = create_test_series(data, index=index)
556+
540557
md_sr[key] = 10
541558
pd_sr[key] = 10
542559
df_equals(md_sr, pd_sr)

0 commit comments

Comments
 (0)