diff --git a/partd/pandas.py b/partd/pandas.py
index 880558c..36c1b01 100644
--- a/partd/pandas.py
+++ b/partd/pandas.py
@@ -211,6 +211,25 @@ def join(dfs):
     if not dfs:
         return pd.DataFrame()
     else:
-        return pd.concat(dfs)
+        result = pd.concat(dfs)
+        # pd.concat can return a non-categorical column when the frames'
+        # categoricals do not line up; cast such columns back to category.
+        # The first frame decides which columns count as categorical.
+        # Look dtypes up on its ``dtypes`` Series so that a column the first
+        # frame lacks (pd.concat keeps the union of columns) is skipped
+        # rather than raising KeyError.
+        first_dtypes = dfs[0].dtypes
+        dtypes = {
+            col: "category"
+            for col in result.columns
+            if (
+                col in first_dtypes.index
+                and isinstance(first_dtypes[col], pd.CategoricalDtype)
+                and not isinstance(result[col].dtype, pd.CategoricalDtype)
+            )
+        }
+        if dtypes:
+            result = result.astype(dtypes)
+        return result
 
 PandasBlocks = partial(Encode, serialize, deserialize, join)
diff --git a/partd/tests/test_pandas.py b/partd/tests/test_pandas.py
index 72c37dc..f64804b 100644
--- a/partd/tests/test_pandas.py
+++ b/partd/tests/test_pandas.py
@@ -146,3 +146,36 @@ def test_index_non_numeric_extension_types(dtype):
     df.index = df.index.astype(dtype)
     df2 = deserialize(serialize(df))
     tm.assert_frame_equal(df, df2)
+
+
+def test_categorical_concat():
+    """Joined blocks stay categorical even when their categories differ."""
+    pytest.importorskip("pandas", minversion="2")
+
+    df1 = pd.DataFrame({"a": ["x", "y"]}, dtype="category")
+    df2 = pd.DataFrame({"a": ["y", "z"]}, dtype="category")
+
+    with PandasBlocks() as p:
+        p.append({"x": df1})
+        p.append({"x": df2})
+        # Read back before the context manager exits.
+        result = p.get(["x"])
+
+    tm.assert_frame_equal(result[0], pd.concat([df1, df2]).astype("category"))
+
+
+def test_categorical_concat_mismatched_columns():
+    """A column absent from the first block must not break the join."""
+    pytest.importorskip("pandas", minversion="2")
+
+    df1 = pd.DataFrame({"a": ["x", "y"]}, dtype="category")
+    df2 = pd.DataFrame({"a": ["y", "z"], "b": [1, 2]}).astype({"a": "category"})
+
+    with PandasBlocks() as p:
+        p.append({"x": df1})
+        p.append({"x": df2})
+        # Read back before the context manager exits.
+        result = p.get(["x"])
+
+    expected = pd.concat([df1, df2]).astype({"a": "category"})
+    tm.assert_frame_equal(result[0], expected)