Skip to content

Commit

Permalink
String dtype: fix convert_dtypes() to convert NaN-string to NA-string (#59470)
Browse files Browse the repository at this point in the history

* String dtype: fix convert_dtypes() to convert NaN-string to NA-string

* fix CoW tracking for conversion to python storage strings

* remove xfails
  • Loading branch information
jorisvandenbossche committed Oct 9, 2024
1 parent 3a362d8 commit 3a03c61
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 17 deletions.
10 changes: 9 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,8 @@ def convert_dtypes(
-------
np.dtype, or ExtensionDtype
"""
from pandas.core.arrays.string_ import StringDtype

inferred_dtype: str | DtypeObj

if (
Expand Down Expand Up @@ -1103,12 +1105,18 @@ def convert_dtypes(
# If we couldn't do anything else, then we retain the dtype
inferred_dtype = input_array.dtype

elif (
convert_string
and isinstance(input_array.dtype, StringDtype)
and input_array.dtype.na_value is np.nan
):
inferred_dtype = pandas_dtype_func("string")

else:
inferred_dtype = input_array.dtype

if dtype_backend == "pyarrow":
from pandas.core.arrays.arrow.array import to_pyarrow_type
from pandas.core.arrays.string_ import StringDtype

assert not isinstance(inferred_dtype, str)

Expand Down
9 changes: 7 additions & 2 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,13 @@ def convert(
convert_non_numeric=True,
)
refs = None
if copy and res_values is values:
res_values = values.copy()
if (
copy
and res_values is values
or isinstance(res_values, NumpyExtensionArray)
and res_values._ndarray is values
):
res_values = res_values.copy()
elif res_values is values:
refs = self.refs

Expand Down
10 changes: 1 addition & 9 deletions pandas/tests/frame/methods/test_convert_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,15 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas as pd
import pandas._testing as tm


class TestConvertDtypes:
# TODO convert_dtypes should not use NaN variant of string dtype, but always NA
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.parametrize(
"convert_integer, expected", [(False, np.dtype("int32")), (True, "Int32")]
)
def test_convert_dtypes(
self, convert_integer, expected, string_storage, using_infer_string
):
def test_convert_dtypes(self, convert_integer, expected, string_storage):
# Specific types are tested in tests/series/test_dtypes.py
# Just check that it works for DataFrame here
df = pd.DataFrame(
Expand Down Expand Up @@ -182,7 +176,6 @@ def test_convert_dtypes_pyarrow_timestamp(self):
result = expected.convert_dtypes(dtype_backend="pyarrow")
tm.assert_series_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_convert_dtypes_avoid_block_splitting(self):
# GH#55341
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": "a"})
Expand All @@ -197,7 +190,6 @@ def test_convert_dtypes_avoid_block_splitting(self):
tm.assert_frame_equal(result, expected)
assert result._mgr.nblocks == 2

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_convert_dtypes_from_arrow(self):
# GH#56581
df = pd.DataFrame([["a", datetime.time(18, 12)]], columns=["a", "b"])
Expand Down
2 changes: 0 additions & 2 deletions pandas/tests/io/parser/dtypes/test_dtypes_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,6 @@ def test_dtype_backend_and_dtype(all_parsers):
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
def test_dtype_backend_string(all_parsers, string_storage):
# GH#36712
pa = pytest.importorskip("pyarrow")
Expand Down Expand Up @@ -507,7 +506,6 @@ def test_dtype_backend_ea_dtype_specified(all_parsers):
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
def test_dtype_backend_pyarrow(all_parsers, request):
# GH#36712
pa = pytest.importorskip("pyarrow")
Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/series/methods/test_convert_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ def test_convert_dtypes(
and params[0]
and not params[1]
):
# If we would convert with convert strings then infer_objects converts
# with the option
expected_dtype = "string[pyarrow_numpy]"
# If convert_string=False and infer_objects=True, we end up with the
# default string dtype instead of preserving object for string data
expected_dtype = pd.StringDtype(na_value=np.nan)

expected = pd.Series(data, dtype=expected_dtype)
tm.assert_series_equal(result, expected)
Expand Down

0 comments on commit 3a03c61

Please sign in to comment.