Skip to content

Commit 55d05c8

Browse files
dchigarev authored and aregm committed
[FIX] Fix modin-project#1683 - losing index names in pd.concat (modin-project#1684)
Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
1 parent b78a628 commit 55d05c8

File tree

3 files changed

+124
-48
lines changed

3 files changed

+124
-48
lines changed

modin/pandas/concat.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# governing permissions and limitations under the License.
1313

1414
import pandas
15+
import numpy as np
1516

1617
from typing import Hashable, Iterable, Mapping, Optional, Union
1718
from pandas._typing import FrameOrSeriesUnion
19+
from pandas.core.dtypes.common import is_list_like
1820

21+
from modin.backends.base.query_compiler import BaseQueryCompiler
1922
from .dataframe import DataFrame
2023
from .series import Series
2124

@@ -108,8 +111,18 @@ def concat(
108111
new_idx_labels = {
109112
k: v.index if axis == 0 else v.columns for k, v in zip(keys, objs)
110113
}
111-
tuples = [(k, o) for k, obj in new_idx_labels.items() for o in obj]
114+
tuples = [
115+
(k, *o) if isinstance(o, tuple) else (k, o)
116+
for k, obj in new_idx_labels.items()
117+
for o in obj
118+
]
112119
new_idx = pandas.MultiIndex.from_tuples(tuples)
120+
if names is not None:
121+
new_idx.names = names
122+
else:
123+
old_name = _determine_name(objs, axis)
124+
if old_name is not None:
125+
new_idx.names = [None] + old_name
113126
else:
114127
new_idx = None
115128
new_query_compiler = objs[0].concat(
@@ -132,3 +145,35 @@ def concat(
132145
else:
133146
result_df.columns = new_idx
134147
return result_df
148+
149+
150+
def _determine_name(objs: Iterable[BaseQueryCompiler], axis: Union[int, str]):
151+
"""
152+
Determine names of index after concatenation along passed axis
153+
154+
Parameters
155+
----------
156+
objs : iterable of QueryCompilers
157+
objects to concatenate
158+
159+
axis : int or str
160+
the axis to concatenate along
161+
162+
Returns
163+
-------
164+
`list` with single element - computed index name, `None` if it could not
165+
be determined
166+
"""
167+
axis = pandas.DataFrame()._get_axis_number(axis)
168+
169+
def get_names(obj):
170+
return obj.columns.names if axis else obj.index.names
171+
172+
names = np.array([get_names(obj) for obj in objs])
173+
174+
# saving old name, only if index names of all objs are the same
175+
if np.all(names == names[0]):
176+
# we must do this check to avoid this calls `list(str_like_name)`
177+
return list(names[0]) if is_list_like(names[0]) else [names[0]]
178+
else:
179+
return None

modin/pandas/test/test_concat.py

Lines changed: 19 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,57 +17,11 @@
1717

1818
import modin.pandas as pd
1919
from modin.pandas.utils import from_pandas
20-
from .utils import df_equals
20+
from .utils import df_equals, generate_dfs, generate_multiindex_dfs, generate_none_dfs
2121

2222
pd.DEFAULT_NPARTITIONS = 4
2323

2424

25-
def generate_dfs():
26-
df = pandas.DataFrame(
27-
{
28-
"col1": [0, 1, 2, 3],
29-
"col2": [4, 5, 6, 7],
30-
"col3": [8, 9, 10, 11],
31-
"col4": [12, 13, 14, 15],
32-
"col5": [0, 0, 0, 0],
33-
}
34-
)
35-
36-
df2 = pandas.DataFrame(
37-
{
38-
"col1": [0, 1, 2, 3],
39-
"col2": [4, 5, 6, 7],
40-
"col3": [8, 9, 10, 11],
41-
"col6": [12, 13, 14, 15],
42-
"col7": [0, 0, 0, 0],
43-
}
44-
)
45-
return df, df2
46-
47-
48-
def generate_none_dfs():
49-
df = pandas.DataFrame(
50-
{
51-
"col1": [0, 1, 2, 3],
52-
"col2": [4, 5, None, 7],
53-
"col3": [8, 9, 10, 11],
54-
"col4": [12, 13, 14, 15],
55-
"col5": [None, None, None, None],
56-
}
57-
)
58-
59-
df2 = pandas.DataFrame(
60-
{
61-
"col1": [0, 1, 2, 3],
62-
"col2": [4, 5, 6, 7],
63-
"col3": [8, 9, 10, 11],
64-
"col6": [12, 13, 14, 15],
65-
"col7": [0, 0, 0, 0],
66-
}
67-
)
68-
return df, df2
69-
70-
7125
def test_df_concat():
7226
df, df2 = generate_dfs()
7327

@@ -207,3 +161,21 @@ def test_concat_with_empty_frame():
207161
pd.concat([modin_empty_df, modin_row]),
208162
pandas.concat([pandas_empty_df, pandas_row]),
209163
)
164+
165+
166+
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("names", [False, True])
def test_concat_multiindex(axis, names):
    """Compare `pd.concat` with `pandas.concat` on frames from `generate_multiindex_dfs`.

    `names` toggles whether explicit level names are passed; when set, one
    name per level of the concatenated axis plus one for the `keys` level is
    generated.
    """
    pandas_first, pandas_second = generate_multiindex_dfs(axis=axis)
    modin_first, modin_second = (
        from_pandas(frame) for frame in (pandas_first, pandas_second)
    )

    level_names = None
    if names:
        # one extra level is introduced by `keys`
        level_count = pandas_first.axes[axis].nlevels + 1
        level_names = [str(level) for level in np.arange(level_count)]

    concat_kwargs = {"keys": ["first", "second"], "axis": axis, "names": level_names}
    modin_result = pd.concat([modin_first, modin_second], **concat_kwargs)
    pandas_result = pandas.concat([pandas_first, pandas_second], **concat_kwargs)
    df_equals(modin_result, pandas_result)

modin/pandas/test/utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,62 @@ def execute_callable(fn, md_kwargs={}, pd_kwargs={}):
592592

593593
def create_test_dfs(*args, **kwargs):
    """Build two frames from the same constructor arguments.

    Returns a tuple whose first element comes from `pd.DataFrame` and whose
    second comes from `pandas.DataFrame`; `pd` is presumably `modin.pandas`
    (confirm against this module's imports).
    """
    via_pd = pd.DataFrame(*args, **kwargs)
    via_pandas = pandas.DataFrame(*args, **kwargs)
    return via_pd, via_pandas
595+
596+
597+
def generate_dfs():
    """Return two 4-row pandas frames that share `col1`-`col3` and differ in their last two columns."""
    # Columns common to both frames; DataFrame copies list data on
    # construction, so reusing these lists does not couple the results.
    shared = {
        "col1": [0, 1, 2, 3],
        "col2": [4, 5, 6, 7],
        "col3": [8, 9, 10, 11],
    }
    first = pandas.DataFrame(
        {**shared, "col4": [12, 13, 14, 15], "col5": [0, 0, 0, 0]}
    )
    second = pandas.DataFrame(
        {**shared, "col6": [12, 13, 14, 15], "col7": [0, 0, 0, 0]}
    )
    return first, second
618+
619+
620+
def generate_multiindex_dfs(axis=1):
    """Return the two `generate_dfs` frames with a two-level MultiIndex on `axis`.

    Parameters
    ----------
    axis : int, default 1
        Axis (0 for index, 1 for columns) whose labels are replaced.

    Returns
    -------
    tuple of two pandas DataFrames whose labels along `axis` are
    ``("a", <original label>)`` with level names ``["name1", "name2"]``.
    """

    def generate_multiindex(index):
        return pandas.MultiIndex.from_tuples(
            [("a", x) for x in index.values], names=["name1", "name2"]
        )

    # `DataFrame.axes` builds a new list on every access, so assigning into
    # that list never reaches the frame; `set_axis` actually relabels it.
    df1, df2 = (
        df.set_axis(generate_multiindex(df.axes[axis]), axis=axis)
        for df in generate_dfs()
    )
    return df1, df2
631+
632+
633+
def generate_none_dfs():
    """Return two 4-row pandas frames; the first has a `None` in `col2` and an all-`None` `col5`."""
    columns_with_gaps = [
        ("col1", [0, 1, 2, 3]),
        ("col2", [4, 5, None, 7]),
        ("col3", [8, 9, 10, 11]),
        ("col4", [12, 13, 14, 15]),
        ("col5", [None, None, None, None]),
    ]
    complete_columns = [
        ("col1", [0, 1, 2, 3]),
        ("col2", [4, 5, 6, 7]),
        ("col3", [8, 9, 10, 11]),
        ("col6", [12, 13, 14, 15]),
        ("col7", [0, 0, 0, 0]),
    ]
    # dict() keeps the pair order, so column order matches the listing above
    df = pandas.DataFrame(dict(columns_with_gaps))
    df2 = pandas.DataFrame(dict(complete_columns))
    return df, df2

0 commit comments

Comments
 (0)