Skip to content

Commit

Permalink
Na return value
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Aug 16, 2023
1 parent 6b26309 commit d862eca
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 34 deletions.
5 changes: 4 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ class StringDtype(StorageExtensionDtype):
#: StringDtype().na_value uses pandas.NA

This comment has been minimized.

Copy link
@jbrockmendel

jbrockmendel Aug 22, 2023

Member

is this comment no longer accurate?

This comment has been minimized.

Copy link
@phofl

phofl Aug 22, 2023

Author Member

good catch

@property
def na_value(self) -> libmissing.NAType:
return libmissing.NA
if self.storage == "pyarrow_numpy":
return np.nan
else:
return libmissing.NA

_metadata = ("storage",)

Expand Down
14 changes: 14 additions & 0 deletions pandas/tests/strings/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,15 @@
import numpy as np

import pandas as pd

object_pyarrow_numpy = ("object", "string[pyarrow_numpy]")


def _convert_na_value(ser, expected):
if ser.dtype != object:
if ser.dtype.storage == "pyarrow_numpy":
expected = expected.fillna(np.nan)
else:
# GH#18463
expected = expected.fillna(pd.NA)
return expected
9 changes: 5 additions & 4 deletions pandas/tests/strings/test_find_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
Series,
_testing as tm,
)
from pandas.tests.strings import object_pyarrow_numpy
from pandas.tests.strings import (
_convert_na_value,
object_pyarrow_numpy,
)

# --------------------------------------------------------------------------------------
# str.contains
Expand Down Expand Up @@ -780,9 +783,7 @@ def test_findall(any_string_dtype):
ser = Series(["fooBAD__barBAD", np.nan, "foo", "BAD"], dtype=any_string_dtype)
result = ser.str.findall("BAD[_]*")
expected = Series([["BAD__", "BAD"], np.nan, [], ["BAD"]])
if ser.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(ser, expected)
tm.assert_series_equal(result, expected)


Expand Down
44 changes: 15 additions & 29 deletions pandas/tests/strings/test_split_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
Series,
_testing as tm,
)
from pandas.tests.strings import (
_convert_na_value,
object_pyarrow_numpy,
)


@pytest.mark.parametrize("method", ["split", "rsplit"])
Expand All @@ -20,9 +24,7 @@ def test_split(any_string_dtype, method):

result = getattr(values.str, method)("_")
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)


Expand All @@ -32,9 +34,7 @@ def test_split_more_than_one_char(any_string_dtype, method):
values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"], dtype=any_string_dtype)
result = getattr(values.str, method)("__")
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)

result = getattr(values.str, method)("__", expand=False)
Expand All @@ -46,9 +46,7 @@ def test_split_more_regex_split(any_string_dtype):
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
result = values.str.split("[,_]")
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)


Expand Down Expand Up @@ -118,8 +116,8 @@ def test_split_object_mixed(expand, method):
def test_split_n(any_string_dtype, method, n):
s = Series(["a b", pd.NA, "b c"], dtype=any_string_dtype)
expected = Series([["a", "b"], pd.NA, ["b", "c"]])

result = getattr(s.str, method)(" ", n=n)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand All @@ -128,9 +126,7 @@ def test_rsplit(any_string_dtype):
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
result = values.str.rsplit("[,_]")
exp = Series([["a,b_c"], ["c_d,e"], np.nan, ["f,g,h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)


Expand All @@ -139,9 +135,7 @@ def test_rsplit_max_number(any_string_dtype):
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype)
result = values.str.rsplit("_", n=1)
exp = Series([["a_b", "c"], ["c_d", "e"], np.nan, ["f_g", "h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)


Expand Down Expand Up @@ -390,7 +384,7 @@ def test_split_nan_expand(any_string_dtype):
# check that these are actually np.nan/pd.NA and not None
# TODO see GH 18463
# tm.assert_frame_equal does not differentiate
if any_string_dtype == "object":
if any_string_dtype in object_pyarrow_numpy:
assert all(np.isnan(x) for x in result.iloc[1])
else:
assert all(x is pd.NA for x in result.iloc[1])
Expand Down Expand Up @@ -455,9 +449,7 @@ def test_partition_series_more_than_one_char(method, exp, any_string_dtype):
s = Series(["a__b__c", "c__d__e", np.nan, "f__g__h", None], dtype=any_string_dtype)
result = getattr(s.str, method)("__", expand=False)
expected = Series(exp)
if s.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand All @@ -480,9 +472,7 @@ def test_partition_series_none(any_string_dtype, method, exp):
s = Series(["a b c", "c d e", np.nan, "f g h", None], dtype=any_string_dtype)
result = getattr(s.str, method)(expand=False)
expected = Series(exp)
if s.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand All @@ -505,9 +495,7 @@ def test_partition_series_not_split(any_string_dtype, method, exp):
s = Series(["abc", "cde", np.nan, "fgh", None], dtype=any_string_dtype)
result = getattr(s.str, method)("_", expand=False)
expected = Series(exp)
if s.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand All @@ -531,9 +519,7 @@ def test_partition_series_unicode(any_string_dtype, method, exp):

result = getattr(s.str, method)("_", expand=False)
expected = Series(exp)
if s.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand Down

0 comments on commit d862eca

Please sign in to comment.