Skip to content

Commit 7ae437f

Browse files
committed
Add tests for pd.NA
1 parent fa6903d commit 7ae437f

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

pygmt/tests/test_clib_to_ndarray.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,42 @@ def test_to_ndarray_pandas_series_numeric(dtype):
110110
npt.assert_array_equal(result, series)
111111

112112

113+
@pytest.mark.parametrize(
114+
"dtype",
115+
[
116+
pytest.param(pd.Int8Dtype(), id="Int8"),
117+
pytest.param(pd.Int16Dtype(), id="Int16"),
118+
pytest.param(pd.Int32Dtype(), id="Int32"),
119+
pytest.param(pd.Int64Dtype(), id="Int64"),
120+
pytest.param(pd.UInt8Dtype(), id="UInt8"),
121+
pytest.param(pd.UInt16Dtype(), id="UInt16"),
122+
pytest.param(pd.UInt32Dtype(), id="UInt32"),
123+
pytest.param(pd.UInt64Dtype(), id="UInt64"),
124+
pytest.param(pd.Float32Dtype(), id="Float32"),
125+
pytest.param(pd.Float64Dtype(), id="Float64"),
126+
pytest.param("int8[pyarrow]", marks=skip_if_no(package="pyarrow")),
127+
pytest.param("int16[pyarrow]", marks=skip_if_no(package="pyarrow")),
128+
pytest.param("int32[pyarrow]", marks=skip_if_no(package="pyarrow")),
129+
pytest.param("int64[pyarrow]", marks=skip_if_no(package="pyarrow")),
130+
pytest.param("uint8[pyarrow]", marks=skip_if_no(package="pyarrow")),
131+
pytest.param("uint16[pyarrow]", marks=skip_if_no(package="pyarrow")),
132+
pytest.param("uint32[pyarrow]", marks=skip_if_no(package="pyarrow")),
133+
pytest.param("uint64[pyarrow]", marks=skip_if_no(package="pyarrow")),
134+
pytest.param("float32[pyarrow]", marks=skip_if_no(package="pyarrow")),
135+
pytest.param("float64[pyarrow]", marks=skip_if_no(package="pyarrow")),
136+
],
137+
)
138+
def test_to_ndarray_pandas_series_numeric_with_na(dtype):
139+
"""
140+
Test the _to_ndarray function with pandas Series with NumPy dtypes and pandas NA.
141+
"""
142+
series = pd.Series([1, pd.NA, 3], dtype=dtype)
143+
assert series.dtype == dtype
144+
result = _to_ndarray(series)
145+
_check_result(result)
146+
npt.assert_array_equal(result, np.array([1, np.nan, 3], dtype=np.float64))
147+
148+
113149
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
114150
@pytest.mark.parametrize(
115151
"dtype",

0 commit comments

Comments
 (0)