From 9cd9078dfe645543a3f1fff0a776e080800e6820 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Wed, 3 May 2023 21:37:59 -0400 Subject: [PATCH] array API fixes for astype --- xarray/core/accessor_str.py | 73 ++++++++++++++++++++++--------------- xarray/core/variable.py | 6 +-- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index c6c4af87d1c..31028f10350 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -51,6 +51,7 @@ import numpy as np +from xarray.core import duck_array_ops from xarray.core.computation import apply_ufunc from xarray.core.types import T_DataArray @@ -2085,13 +2086,16 @@ def _get_res_multi(val, pat): else: # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=_get_res_multi, - func_args=(pat,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: maxgroups}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=_get_res_multi, + func_args=(pat,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: maxgroups}, + ), + self._obj.dtype.kind, + ) def extractall( self, @@ -2258,15 +2262,18 @@ def _get_res(val, ipat, imaxcount=maxcount, dtype=self._obj.dtype): return res - return self._apply( - # dtype MUST be object or strings can be truncated - # See: https://github.com/numpy/numpy/issues/8352 - func=_get_res, - func_args=(pat,), - dtype=np.object_, - output_core_dims=[[group_dim, match_dim]], - output_sizes={group_dim: maxgroups, match_dim: maxcount}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + # dtype MUST be object or strings can be truncated + # See: https://github.com/numpy/numpy/issues/8352 + func=_get_res, + func_args=(pat,), + dtype=np.object_, + output_core_dims=[[group_dim, match_dim]], + output_sizes={group_dim: maxgroups, match_dim: maxcount}, + ), + self._obj.dtype.kind, + ) def findall( self, @@ -2385,13 +2392,16 @@ def _partitioner( # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=arrfunc, - func_args=(sep,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: 3}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=arrfunc, + func_args=(sep,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: 3}, + ), + self._obj.dtype.kind, + ) def partition( self, @@ -2510,13 +2520,16 @@ def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype): # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=_dosplit, - func_args=(sep,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: maxsplit}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=_dosplit, + func_args=(sep,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: maxsplit}, + ), + self._obj.dtype.kind, + ) def split( self, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 25dfbf85556..feedd891c8d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1420,7 +1420,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): pads = [(0, 0) if d != dim else dim_pad for d in self.dims] data = np.pad( - trimmed_data.astype(dtype), + duck_array_ops.astype(trimmed_data, dtype), pads, mode="constant", constant_values=fill_value, @@ -1569,7 +1569,7 @@ def pad( pad_option_kwargs["reflect_type"] = reflect_type array = np.pad( - self.data.astype(dtype, copy=False), + duck_array_ops.astype(self.data, dtype, copy=False), pad_width_by_index, mode=mode, **pad_option_kwargs, @@ -2437,7 +2437,7 @@ def rolling_window( """ if fill_value is dtypes.NA: # np.nan is passed dtype, fill_value = dtypes.maybe_promote(self.dtype) - var = self.astype(dtype, copy=False) + var = duck_array_ops.astype(self, dtype, copy=False) else: dtype = self.dtype var = self