diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21a954d..5d4170b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies shell: bash -l {0} - run: conda install pytest locket numpy toolz pandas blosc pyzmq -c conda-forge + run: conda install pytest locket numpy toolz pandas blosc pyzmq pyarrow -c conda-forge - name: Install shell: bash -l {0} diff --git a/partd/pandas.py b/partd/pandas.py index c824c24..6c95351 100644 --- a/partd/pandas.py +++ b/partd/pandas.py @@ -108,6 +108,9 @@ def index_to_header_bytes(ind): cat = None values = ind.values + if is_extension_array_dtype(ind): + return None, dumps(ind) + header = (type(ind), {k: getattr(ind, k, None) for k in ind._attributes}, values.dtype, cat) bytes = pnp.compress(pnp.serialize(values), values.dtype) return header, bytes diff --git a/partd/tests/test_pandas.py b/partd/tests/test_pandas.py index 5e14f89..72c37dc 100644 --- a/partd/tests/test_pandas.py +++ b/partd/tests/test_pandas.py @@ -6,6 +6,11 @@ import pandas.testing as tm import os +try: + import pyarrow as pa +except ImportError: + pa = None + from partd.pandas import PandasColumns, PandasBlocks, serialize, deserialize @@ -115,3 +120,29 @@ def test_other_extension_types(): df = pd.DataFrame({"A": a}) df2 = deserialize(serialize(df)) tm.assert_frame_equal(df, df2) + +@pytest.mark.parametrize("dtype", ["Int64", "Int32", "Float64", "Float32"]) +def test_index_numeric_extension_types(dtype): + pytest.importorskip("pandas", minversion="1.4.0") + + df = pd.DataFrame({"x": [1, 2, 3]}, index=[4, 5, 6]) + df.index = df.index.astype(dtype) + df2 = deserialize(serialize(df)) + tm.assert_frame_equal(df, df2) + +@pytest.mark.parametrize( + "dtype", + [ + "string[python]", + pytest.param( + "string[pyarrow]", + marks=pytest.mark.skipif(pa is None, reason="Requires pyarrow"), + ), + ], +) +def test_index_non_numeric_extension_types(dtype): + pytest.importorskip("pandas", minversion="1.4.0") + df = pd.DataFrame({"x": [1, 2, 3]}, index=["a", "b", "c"]) + df.index = df.index.astype(dtype) + df2 = deserialize(serialize(df)) + tm.assert_frame_equal(df, df2)