Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ArrowStringArray] fix test_astype_int, test_astype_float #41018

Merged
merged 4 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@
from pandas.util._decorators import doc
from pandas.util._validators import validate_fillna_kwargs

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.common import (
is_array_like,
is_bool_dtype,
is_dtype_equal,
is_integer,
is_integer_dtype,
is_object_dtype,
is_scalar,
is_string_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import register_extension_dtype
from pandas.core.dtypes.missing import isna
Expand All @@ -48,6 +51,7 @@
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import Int64Dtype
from pandas.core.arrays.numeric import NumericDtype
from pandas.core.arrays.string_ import StringDtype
from pandas.core.indexers import (
check_array_indexer,
Expand Down Expand Up @@ -290,10 +294,14 @@ def to_numpy( # type: ignore[override]
"""
# TODO: copy argument is ignored

if na_value is lib.no_default:
na_value = self._dtype.na_value
result = self._data.__array__(dtype=dtype)
result[isna(result)] = na_value
result = np.array(self._data, dtype=dtype)
if self._data.null_count > 0:
if na_value is lib.no_default:
if dtype and np.issubdtype(dtype, np.floating):
return result
na_value = self._dtype.na_value
mask = self.isna()
result[mask] = na_value
return result

def __len__(self) -> int:
Expand Down Expand Up @@ -737,6 +745,24 @@ def value_counts(self, dropna: bool = True) -> Series:

return Series(counts, index=index).astype("Int64")

def astype(self, dtype, copy=True):
jreback marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you type here (followups are ok)

dtype = pandas_dtype(dtype)

if is_dtype_equal(dtype, self.dtype):
if copy:
return self.copy()
return self

elif isinstance(dtype, NumericDtype):
data = self._data.cast(pa.from_numpy_dtype(dtype.numpy_dtype))
return dtype.__from_arrow__(data)

elif isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)

return super().astype(dtype, copy)

# ------------------------------------------------------------------------
# String methods interface

Expand Down
31 changes: 15 additions & 16 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Tests for the str accessors are in pandas/tests/strings/test_string_array.py
"""

import re

import numpy as np
import pytest

Expand Down Expand Up @@ -325,32 +327,29 @@ def test_from_sequence_no_mutate(copy, cls, request):
tm.assert_numpy_array_equal(nan_arr, expected)


def test_astype_int(dtype, request):
if dtype == "arrow_string":
reason = "Cannot interpret 'Int64Dtype()' as a data type"
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
request.node.add_marker(mark)
def test_astype_int(dtype):
arr = pd.array(["1", "2", "3"], dtype=dtype)
result = arr.astype("int64")
expected = np.array([1, 2, 3], dtype="int64")
tm.assert_numpy_array_equal(result, expected)

arr = pd.array(["1", pd.NA, "3"], dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prob best in a dedicate _errors test

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. done in simonjayhawkins@e1577d4

will open a PR with other follow-ups

msg = re.escape("int() argument must be a string, a bytes-like object or a number")
with pytest.raises(TypeError, match=msg):
arr.astype("int64")


def test_astype_nullable_int(dtype):
arr = pd.array(["1", pd.NA, "3"], dtype=dtype)

result = arr.astype("Int64")
expected = pd.array([1, pd.NA, 3], dtype="Int64")
tm.assert_extension_array_equal(result, expected)


def test_astype_float(dtype, any_float_allowed_nullable_dtype, request):
def test_astype_float(dtype, any_float_allowed_nullable_dtype):
# Don't compare arrays (37974)

if dtype == "arrow_string":
if any_float_allowed_nullable_dtype in {"Float32", "Float64"}:
reason = "Cannot interpret 'Float32Dtype()' as a data type"
else:
reason = "float() argument must be a string or a number, not 'NAType'"
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
request.node.add_marker(mark)

ser = pd.Series(["1.1", pd.NA, "3.3"], dtype=dtype)

result = ser.astype(any_float_allowed_nullable_dtype)
expected = pd.Series([1.1, np.nan, 3.3], dtype=any_float_allowed_nullable_dtype)
tm.assert_series_equal(result, expected)
Expand Down
7 changes: 5 additions & 2 deletions pandas/tests/series/methods/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,9 @@ class TestAstypeString:
# currently no way to parse IntervalArray from a list of strings
],
)
def test_astype_string_to_extension_dtype_roundtrip(self, data, dtype, request):
def test_astype_string_to_extension_dtype_roundtrip(
self, data, dtype, request, nullable_string_dtype
):
if dtype == "boolean" or (
dtype in ("period[M]", "datetime64[ns]", "timedelta64[ns]") and NaT in data
):
Expand All @@ -389,7 +391,8 @@ def test_astype_string_to_extension_dtype_roundtrip(self, data, dtype, request):
request.node.add_marker(mark)
# GH-40351
s = Series(data, dtype=dtype)
tm.assert_series_equal(s, s.astype("string").astype(dtype))
result = s.astype(nullable_string_dtype).astype(dtype)
tm.assert_series_equal(result, s)


class TestAstypeCategorical:
Expand Down