From d0b8600d31fb58d1bb1b550c2f0a77fd95e2472c Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Thu, 31 Dec 2020 01:35:25 -0500 Subject: [PATCH] more work on vectorized functions --- xarray/core/accessor_str.py | 193 +++++++++++++++++------------- xarray/tests/test_accessor_str.py | 92 +++++++++++++- 2 files changed, 201 insertions(+), 84 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 489ad25d1d7..872cd97df1b 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -185,7 +185,7 @@ def _re_compile( pat: Union[str, bytes, Pattern, Any], flags: int, case: bool = None, - ) -> Pattern: + ) -> Union[Pattern, Any]: is_compiled_re = isinstance(pat, self._pattern_type) if is_compiled_re and flags != 0: @@ -206,9 +206,15 @@ def _re_compile( if not case: flags |= re.IGNORECASE - pat = self._stringify(pat) + if getattr(pat, "dtype", None) != np.object_: + pat = self._stringify(pat) func = lambda x: re.compile(x, flags=flags) - return self._apply(func, obj=pat, dtype=np.object_) + if isinstance(pat, np.ndarray): + # apply_ufunc doesn't work for numpy arrays with output object dtypes + func = np.vectorize(func) + return func(pat) + else: + return self._apply(func, obj=pat, dtype=np.object_) def len(self) -> Any: """ @@ -262,10 +268,14 @@ def get( """ Extract character number `i` from each string in the array. + If `i` is an array-like, they are broadcast against the array and + applied elementwise. + Parameters ---------- - i : int + i : int or array-like of int Position of element to extract. + If array-like, it is broadcast. default : optional Value for out-of-range index. If not specified (None) defaults to an empty string. @@ -292,14 +302,20 @@ def slice( """ Slice substrings from each string in the array. + If `start`, `stop`, or 'step` is an array-like, they are broadcast + against the array and applied elementwise + Parameters ---------- - start : int, optional + start : int or array-like of int, optional Start position for slice operation. - stop : int, optional + If array-like, it is broadcast. + stop : int or array-like of int, optional Stop position for slice operation. - step : int, optional + If array-like, it is broadcast. + step : int or array-like of int, optional Step size for slice operation. + If array-like, it is broadcast. Returns ------- @@ -319,17 +335,17 @@ def slice_replace( Parameters ---------- - start : int, optional + start : int or array-like of int, optional Left index position to use for the slice. If not specified (None), the slice is unbounded on the left, i.e. slice from the start - of the string. - stop : int, optional + of the string. If array-like, it is broadcast. + stop : int or array-like of int, optional Right index position to use for the slice. If not specified (None), the slice is unbounded on the right, i.e. slice until the - end of the string. - repl : str, optional + end of the string. If array-like, it is broadcast. + repl : str or array-like of str, optional String for replacement. If not specified, the sliced region - is replaced with an empty string. + is replaced with an empty string. If array-like, it is broadcast. Returns ------- @@ -338,7 +354,7 @@ def slice_replace( repl = self._stringify(repl) def f(x, istart, istop, irepl): - if len(x[istart:stop]) == 0: + if len(x[istart:istop]) == 0: local_stop = istart else: local_stop = istop @@ -370,7 +386,7 @@ def cat( *others : str or array-like of str Strings or array-like of strings to concatenate elementwise with the current DataArray. - sep : str or array-like, default `""`. + sep : str or array-like, default: "". Seperator to use between strings. It is broadcast in the same way as the other input strings. If array-like, its dimensions will be placed at the end of the output array dimensions. @@ -449,10 +465,10 @@ def join( Parameters ---------- - dim : Hashable, optional + dim : hashable, optional Dimension along which the strings should be concatenated. Optional for 0D or 1D DataArrays, required for multidimensional DataArrays. - sep : str or array-like, default `""`. + sep : str or array-like, default: "". Seperator to use between strings. It is broadcast in the same way as the other input strings. If array-like, its dimensions will be placed at the end of the output array dimensions. @@ -669,7 +685,7 @@ def normalize( Parameters ---------- - form : {"NFC", "NFKC", "NFD", and "NFKD"} + form : {"NFC", "NFKC", "NFD", "NFKD"} Unicode form. Returns @@ -783,7 +799,7 @@ def count( self, pat: Union[str, bytes, Pattern, Any], flags: int = 0, - case: bool = True, + case: bool = None, ) -> Any: """ Count occurrences of pattern in each string of the array. @@ -792,11 +808,14 @@ def count( pattern is repeated in each of the string elements of the :class:`~xarray.DataArray`. + `pat` can either be a single `str` or `re.Pattern` or an array-like + of `str` or `re.Pattern`. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern A string containing a regular expression or a compiled regular - expression object. + expression object. If array-like, it is broadcast. flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. @@ -813,8 +832,8 @@ def count( """ pat = self._re_compile(pat, flags, case) - f = lambda x: len(pat.findall(x)) - return self._apply(f, dtype=int) + f = lambda x, ipat: len(ipat.findall(x)) + return self._apply(f, pat, dtype=int) def startswith( self, @@ -823,10 +842,14 @@ def startswith( """ Test if the start of each string in the array matches a pattern. + The pattern can either be a `str` or array-like of `str`. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- pat : str Character sequence. Regular expressions are not accepted. + If array-like, it is broadcast. Returns ------- @@ -845,10 +868,14 @@ def endswith( """ Test if the end of each string in the array matches a pattern. + The pattern can either be a `str` or array-like of `str`. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- pat : str Character sequence. Regular expressions are not accepted. + If array-like, it is broadcast. Returns ------- @@ -864,20 +891,22 @@ def pad( self, width: Union[int, Any], side: str = "left", - fillchar: Union[str, bytes] = " ", + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad strings in the array up to width. Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be filled with character defined in `fillchar`. + If array-like, it is broadcast. side : {"left", "right", "both"}, default: "left" Side from which to fill resulting string. - fillchar : str, default: " " - Additional character for filling, default is whitespace. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- @@ -898,34 +927,37 @@ def pad( def _padder( self, func: Callable, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Wrapper function to handle padding operations """ fillchar = self._stringify(fillchar) - if len(fillchar) != 1: - raise TypeError("fillchar must be a character, not str") - f = lambda s, w: func(s, int(w), fillchar) - return self._apply(f, width) + def overfunc(x, iwidth, ifillchar): + if len(ifillchar) != 1: + raise TypeError("fillchar must be a character, not str") + return func(x, int(iwidth), ifillchar) + + return self._apply(overfunc, width, fillchar) def center( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad left and right side of each string in the array. Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- @@ -936,19 +968,20 @@ def center( def ljust( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad right side of each string in the array. Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- @@ -959,19 +992,20 @@ def ljust( def rjust( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad left side of each string in the array. Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- @@ -980,10 +1014,7 @@ def rjust( func = self._obj.dtype.type.rjust return self._padder(func=func, width=width, fillchar=fillchar) - def zfill( - self, - width: int, - ) -> Any: + def zfill(self, width: Union[int, Any]) -> Any: """ Pad each string in the array by prepending '0' characters. @@ -993,9 +1024,9 @@ def zfill( Parameters ---------- - width : int + width : int or array-like of int Minimum length of resulting string; strings with length less - than `width` be prepended with '0' characters. + than `width` be prepended with '0' characters. If array-like, it is broadcast. Returns ------- @@ -1485,7 +1516,7 @@ def extract( pat : str or re.Pattern A string containing a regular expression or a compiled regular expression object. - dim : hashable or `None` + dim : hashable or None Name of the new dimension to store the captured strings in. If None, the pattern must have only one capture group and the resulting DataArray will have the same size as the original. @@ -1508,7 +1539,7 @@ def extract( ValueError `pat` has no capture groups. ValueError - `dim` is `None` and there is more than one capture group. + `dim` is None and there is more than one capture group. ValueError `case` is set when `pat` is a compiled regular expression. KeyError @@ -1619,10 +1650,10 @@ def extractall( pat : str or re.Pattern A string containing a regular expression or a compiled regular expression object. - group_dim: hashable + group_dim : hashable Name of the new dimensions corresponding to the capture groups. This dimension is added to the new DataArray first. - match_dim: hashable + match_dim : hashable Name of the new dimensions corresponding to the matches for each group. This dimension is added to the new DataArray second. case : bool, default: True @@ -1857,7 +1888,7 @@ def _partitioner( # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) f = lambda x: np.array(func(x, sep), dtype=self._obj.dtype) @@ -1888,10 +1919,10 @@ def partition( Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the 3 elements in. If `None`, place the results as list elements in an object DataArray - sep : str, default `" "` + sep : str, default: " " String to split on. Returns @@ -1924,10 +1955,10 @@ def rpartition( Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the 3 elements in. If `None`, place the results as list elements in an object DataArray - sep : str, default `" "` + sep : str, default: " " String to split on. Returns @@ -1962,10 +1993,10 @@ def _splitter( # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) f_count = lambda x: max(len(func(x, sep, maxsplit)), 1) - maxsplit = self._apply(f_count, dtype=np.int_).max().data.tolist() - 1 + maxsplit = self._apply(f_count, dtype=np.int_).max().data.item() - 1 def _dosplit(mystr, sep=sep, maxsplit=maxsplit, dtype=self._obj.dtype): res = func(mystr, sep, maxsplit) @@ -2002,14 +2033,14 @@ def split( Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the results in. If `None`, place the results as list elements in an object DataArray - sep : str, default is split on any whitespace. - String to split on. - maxsplit : int, default -1 (all) + sep : str, default: None + String to split on. If ``None`` (the default), split on any whitespace. + maxsplit : int, default: -1 Limit number of splits in output, starting from the beginning. - -1 will return all splits. + If -1 (the default), return all splits. Returns ------- @@ -2116,14 +2147,14 @@ def rsplit( Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the results in. If `None`, place the results as list elements in an object DataArray - sep : str, default is split on any whitespace. - String to split on. - maxsplit : int, default -1 (all) + sep : str, default: None + String to split on. If ``None`` (the default), split on any whitespace. + maxsplit : int, default: -1 Limit number of splits in output, starting from the end. - -1 will return all splits. + If -1 (the default), return all splits. The final number of split values may be less than this if there are no DataArray elements with that many values. @@ -2231,9 +2262,9 @@ def get_dummies( Parameters ---------- - dim : Hashable + dim : hashable Name for the dimension to place the results in. - sep : str, default `"|"`. + sep : str, default: "|". String to split on. Returns @@ -2273,7 +2304,7 @@ def get_dummies( """ # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) sep = self._stringify(sep) f_set = lambda x: set(x.split(sep)) - {self._stringify("")} diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index 5107472b791..3260d845761 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -69,10 +69,34 @@ def test_dask(): def test_count(dtype): values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) - result = values.str.count("f[o]+") + pat_str = dtype(r"f[o]+") + pat_re = re.compile(pat_str) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + expected = xr.DataArray([1, 2, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) + + +def test_count_array(dtype): + values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) + pat_str = np.array([r"f[o]+", r"o", r"m"]).astype(dtype) + pat_re = np.array([re.compile(x) for x in pat_str]) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + + expected = xr.DataArray([1, 4, 3]) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) def test_contains(dtype): @@ -113,6 +137,29 @@ def test_starts_ends_with(dtype): assert_equal(result, expected) +def test_starts_ends_with_array(dtype): + values = xr.DataArray( + ["om", "foo_nom", "nom", "bar_foo", "foo_bar"], dims="X" + ).astype(dtype) + pat = xr.DataArray(["foo", "bar"], dims="Y").astype(dtype) + + result = values.str.startswith(pat) + expected = xr.DataArray( + [[False, False], [True, False], [False, False], [False, True], [True, False]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.endswith(pat) + expected = xr.DataArray( + [[False, False], [False, False], [False, False], [True, False], [False, True]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + def test_case_bytes(): dtype = np.bytes_ value = xr.DataArray(["SOme wOrd"]).astype(dtype) @@ -1643,6 +1690,17 @@ def test_slice(dtype): raise +def test_slice_array(dtype): + arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) + start = xr.DataArray([1, 2, 3]) + stop = 5 + + result = arr.str.slice(start=start, stop=stop) + exp = xr.DataArray(["afoo", "bar", "az"]).astype(dtype) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + def test_slice_replace(dtype): da = lambda x: xr.DataArray(x).astype(dtype) values = da(["short", "a bit longer", "evenlongerthanthat", ""]) @@ -1688,6 +1746,22 @@ def test_slice_replace(dtype): assert_equal(result, expected) +def test_slice_replace_array(dtype): + values = xr.DataArray(["short", "a bit longer", "evenlongerthanthat", ""]).astype( + dtype + ) + start = 2 + stop = np.array([4, 5, None, 7]) + repl = "test" + + expected = xr.DataArray(["shtestt", "a test longer", "evtest", "test"]).astype( + dtype + ) + result = values.str.slice_replace(start, stop, repl) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + def test_strip_lstrip_rstrip(dtype): values = xr.DataArray([" aa ", " bb \n", "cc "]).astype(dtype) @@ -1823,6 +1897,18 @@ def test_get_default(dtype): assert_equal(result, expected) +def test_get_array(dtype): + values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"], dims=["X"]).astype(dtype) + inds = xr.DataArray([0, 2], dims=["Y"]) + + result = values.str.get(inds) + expected = xr.DataArray( + [["a", "b"], ["c", "d"], ["f", "g"]], dims=["X", "Y"] + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + def test_encode_decode(): data = xr.DataArray(["a", "b", "a\xe4"]) encoded = data.str.encode("utf-8")