Skip to content

Commit

Permalink
Use pickle for pandas Index with extension dtype (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored Mar 7, 2023
1 parent d1faa88 commit c91ad12
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:

- name: Install dependencies
shell: bash -l {0}
run: conda install pytest locket numpy toolz pandas blosc pyzmq -c conda-forge
run: conda install pytest locket numpy toolz pandas blosc pyzmq pyarrow -c conda-forge

- name: Install
shell: bash -l {0}
Expand Down
3 changes: 3 additions & 0 deletions partd/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def index_to_header_bytes(ind):
cat = None
values = ind.values

if is_extension_array_dtype(ind):
return None, dumps(ind)

header = (type(ind), {k: getattr(ind, k, None) for k in ind._attributes}, values.dtype, cat)
bytes = pnp.compress(pnp.serialize(values), values.dtype)
return header, bytes
Expand Down
31 changes: 31 additions & 0 deletions partd/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import pandas.testing as tm
import os

try:
import pyarrow as pa
except ImportError:
pa = None

from partd.pandas import PandasColumns, PandasBlocks, serialize, deserialize


Expand Down Expand Up @@ -115,3 +120,29 @@ def test_other_extension_types():
df = pd.DataFrame({"A": a})
df2 = deserialize(serialize(df))
tm.assert_frame_equal(df, df2)

@pytest.mark.parametrize("dtype", ["Int64", "Int32", "Float64", "Float32"])
def test_index_numeric_extension_types(dtype):
pytest.importorskip("pandas", minversion="1.4.0")

df = pd.DataFrame({"x": [1, 2, 3]}, index=[4, 5, 6])
df.index = df.index.astype(dtype)
df2 = deserialize(serialize(df))
tm.assert_frame_equal(df, df2)

@pytest.mark.parametrize(
"dtype",
[
"string[python]",
pytest.param(
"string[pyarrow]",
marks=pytest.mark.skipif(pa is None, reason="Requires pyarrow"),
),
],
)
def test_index_non_numeric_extension_types(dtype):
pytest.importorskip("pandas", minversion="1.4.0")
df = pd.DataFrame({"x": [1, 2, 3]}, index=["a", "b", "c"])
df.index = df.index.astype(dtype)
df2 = deserialize(serialize(df))
tm.assert_frame_equal(df, df2)

0 comments on commit c91ad12

Please sign in to comment.