Skip to content

Commit

Permalink
FIX-modin-project#1700: Fix metadata for concat and mask when axis=1
Browse files Browse the repository at this point in the history
Signed-off-by: Devin Petersohn <devin.petersohn@gmail.com>
  • Loading branch information
devin-petersohn committed Jul 24, 2020
1 parent 49f462c commit f9cdfca
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
11 changes: 7 additions & 4 deletions modin/backends/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,14 @@ def concat(self, axis, other, **kwargs):
ignore_index = kwargs.get("ignore_index", False)
other_modin_frame = [o._modin_frame for o in other]
new_modin_frame = self._modin_frame._concat(axis, other_modin_frame, join, sort)
result = self.__constructor__(new_modin_frame)
if ignore_index:
new_modin_frame.index = pandas.RangeIndex(
len(self.index) + sum(len(o.index) for o in other)
)
return self.__constructor__(new_modin_frame)
if axis == 0:
return result.reset_index(drop=True)
else:
result.columns = pandas.RangeIndex(len(result.columns))
return result
return result

# END Append/Concat/Join

Expand Down
2 changes: 1 addition & 1 deletion modin/engines/base/frame/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def mask(
new_col_widths = [len(idx) for _, idx in col_partitions_list.items()]
new_columns = self.columns[sorted(col_numeric_idx)]
if self._dtypes is not None:
new_dtypes = self.dtypes[sorted(col_numeric_idx)]
new_dtypes = self.dtypes.iloc[sorted(col_numeric_idx)]
else:
new_dtypes = None
else:
Expand Down
11 changes: 11 additions & 0 deletions modin/pandas/test/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ def test_concat_on_column():
pandas.concat([df, df2], axis="columns"),
)

modin_result = pd.concat(
[pd.Series(np.ones(10)), pd.Series(np.ones(10))], axis=1, ignore_index=True
)
pandas_result = pandas.concat(
[pandas.Series(np.ones(10)), pandas.Series(np.ones(10))],
axis=1,
ignore_index=True,
)
df_equals(modin_result, pandas_result)
assert modin_result.dtypes.equals(pandas_result.dtypes)


def test_invalid_axis_errors():
df, df2 = generate_dfs()
Expand Down

0 comments on commit f9cdfca

Please sign in to comment.