Skip to content

Commit 5c5a8e0

Browse files
authored
FIX-#2253: loc assignment fixed in case of (1, 1) shape frame (#2316)
Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
1 parent a7d3093 commit 5c5a8e0

File tree

2 files changed

+92
-14
lines changed

2 files changed

+92
-14
lines changed

modin/pandas/indexing.py

Lines changed: 92 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ def is_slice(x):
6060
return isinstance(x, slice)
6161

6262

63+
def compute_sliced_len(slc, sequence_len):
    """
    Return how many elements a slice selects from a sequence.

    Parameters
    ----------
    slc: slice
        Slice object to measure.
    sequence_len: int
        Length of the sequence the slice would be applied to.

    Returns
    -------
    int
        Number of elements that remain after applying `slc` to a sequence
        of length `sequence_len`.
    """
    # `slice.indices` resolves omitted/negative/out-of-bounds bounds into a
    # concrete (start, stop, step) triple for the given length; a range built
    # from that triple has exactly as many items as the slice would select.
    start, stop, step = slc.indices(sequence_len)
    return len(range(start, stop, step))
81+
82+
6383
def is_2d(x):
6484
"""
6585
Implement [METHOD_NAME].
@@ -293,7 +313,7 @@ def __getitem__(self, row_lookup, col_lookup, ndim):
293313
)
294314
return self.df.__constructor__(query_compiler=qc_view).squeeze(axis=axis)
295315

296-
def __setitem__(self, row_lookup, col_lookup, item):
316+
def __setitem__(self, row_lookup, col_lookup, item, axis=None):
297317
"""
298318
Implement [METHOD_NAME].
299319
@@ -317,15 +337,11 @@ def __setitem__(self, row_lookup, col_lookup, item):
317337
col_lookup = range(len(self.qc.columns))[col_lookup]
318338
# This is True when we are dealing with assignment of a full column. This case
319339
# should be handled in a fastpath with `df[col] = item`.
320-
if (
321-
len(row_lookup) == len(self.qc.index)
322-
and len(col_lookup) == 1
323-
and hasattr(self.df, "columns")
324-
):
340+
if axis == 0:
325341
self.df[self.df.columns[col_lookup][0]] = item
326342
# This is True when we are assigning to a full row. We want to reuse the setitem
327343
# mechanism to operate along only one axis for performance reasons.
328-
elif len(col_lookup) == len(self.qc.columns) and len(row_lookup) == 1:
344+
elif axis == 1:
329345
if hasattr(item, "_query_compiler"):
330346
item = item._query_compiler
331347
new_qc = self.qc.setitem(1, self.qc.index[row_lookup[0]], item)
@@ -417,6 +433,57 @@ def _write_items(self, row_lookup, col_lookup, item):
417433
new_qc = self.qc.write_items(row_lookup, col_lookup, item)
418434
self.df._create_or_update_from_compiler(new_qc, inplace=True)
419435

436+
def _determine_setitem_axis(self, row_lookup, col_lookup, row_scaler, col_scaler):
    """
    Pick the axis along which an assignment should be performed.

    Parameters
    ----------
    row_lookup: slice or list
        Indexer for rows.
    col_lookup: slice or list
        Indexer for columns.
    row_scaler: bool
        Whether the indexer for rows was a scalar or not.
    col_scaler: bool
        Whether the indexer for columns was a scalar or not.

    Returns
    -------
    int or None
        None if the assignment spans both axes, otherwise the number of
        the axis to assign along.

    Notes
    -----
    axis = 0: column assignment df[col] = item
    axis = 1: row assignment df.loc[row] = item
    axis = None: assignment along both axes
    """
    # In a (1, 1) frame the lookup lengths satisfy both the "full column" and
    # the "full row" length checks below, so the scalar flags decide instead.
    if self.df.shape == (1, 1):
        if not (row_scaler ^ col_scaler):
            return None
        return 1 if row_scaler else 0

    def lookup_length(lookup, axis):
        # Lists report their own length; slices are measured against the
        # labels of the axis they index (fetched only when actually needed).
        if not isinstance(lookup, slice):
            return len(lookup)
        labels = self.qc.index if axis == 0 else self.qc.columns
        return compute_sliced_len(lookup, len(labels))

    row_lookup_len = lookup_length(row_lookup, 0)
    col_lookup_len = lookup_length(col_lookup, 1)

    # Every row of exactly one column of a DataFrame: column assignment.
    if (
        row_lookup_len == len(self.qc.index)
        and col_lookup_len == 1
        and isinstance(self.df, DataFrame)
    ):
        return 0
    # Every column of exactly one row: row assignment.
    if col_lookup_len == len(self.qc.columns) and row_lookup_len == 1:
        return 1
    return None
486+
420487

421488
class _LocIndexer(_LocationIndexerBase):
422489
"""An indexer for modin_df.loc[] functionality."""
@@ -507,7 +574,7 @@ def __setitem__(self, key, item):
507574
-------
508575
What this returns (if anything)
509576
"""
510-
row_loc, col_loc, _, __, ___ = _parse_tuple(key)
577+
row_loc, col_loc, _, row_scaler, col_scaler = _parse_tuple(key)
511578
if isinstance(row_loc, list) and len(row_loc) == 1:
512579
if row_loc[0] not in self.qc.index:
513580
index = self.qc.index.insert(len(self.qc.index), row_loc[0])
@@ -525,7 +592,14 @@ def __setitem__(self, key, item):
525592
self.qc = self.df._query_compiler
526593
else:
527594
row_lookup, col_lookup = self._compute_lookup(row_loc, col_loc)
528-
super(_LocIndexer, self).__setitem__(row_lookup, col_lookup, item)
595+
super(_LocIndexer, self).__setitem__(
596+
row_lookup,
597+
col_lookup,
598+
item,
599+
axis=self._determine_setitem_axis(
600+
row_lookup, col_lookup, row_scaler, col_scaler
601+
),
602+
)
529603

530604
def _compute_enlarge_labels(self, locator, base_index):
531605
"""
@@ -663,12 +737,19 @@ def __setitem__(self, key, item):
663737
-------
664738
What this returns (if anything)
665739
"""
666-
row_loc, col_loc, _, __, ___ = _parse_tuple(key)
740+
row_loc, col_loc, _, row_scaler, col_scaler = _parse_tuple(key)
667741
self._check_dtypes(row_loc)
668742
self._check_dtypes(col_loc)
669743

670744
row_lookup, col_lookup = self._compute_lookup(row_loc, col_loc)
671-
super(_iLocIndexer, self).__setitem__(row_lookup, col_lookup, item)
745+
super(_iLocIndexer, self).__setitem__(
746+
row_lookup,
747+
col_lookup,
748+
item,
749+
axis=self._determine_setitem_axis(
750+
row_lookup, col_lookup, row_scaler, col_scaler
751+
),
752+
)
672753

673754
def _compute_lookup(self, row_loc, col_loc):
674755
"""

modin/pandas/test/dataframe/test_indexing.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,6 @@ def test_loc_multi_index():
390390
@pytest.mark.parametrize("index", [["row1", "row2", "row3"], ["row1"]])
391391
@pytest.mark.parametrize("columns", [["col1", "col2"], ["col1"]])
392392
def test_loc_assignment(index, columns):
393-
if len(index) == 1 and len(columns) == 1:
394-
pytest.skip("See Modin issue #2253 for details")
395-
396393
md_df, pd_df = create_test_dfs(index=index, columns=columns)
397394
for i, ind in enumerate(index):
398395
for j, col in enumerate(columns):

0 commit comments

Comments
 (0)