Skip to content

Commit

Permalink
implement requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
toddrjen committed Dec 31, 2020
1 parent 8e05f35 commit 7d7dd6f
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 277 deletions.
22 changes: 13 additions & 9 deletions xarray/core/accessor_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,18 @@ def _apply(
)

def _re_compile(
self, pat: Union[str, bytes, Pattern], flags: int, case: bool = None
self,
pat: Union[str, bytes, Pattern, Any],
flags: int,
case: bool = None,
) -> Pattern:
is_compiled_re = isinstance(pat, self._pattern_type)

if is_compiled_re and flags != 0:
raise ValueError("flags cannot be set when pat is a compiled regex")
raise ValueError("Flags cannot be set when pat is a compiled regex.")

if is_compiled_re and case is not None:
raise ValueError("case cannot be set when pat is a compiled regex")
raise ValueError("Case cannot be set when pat is a compiled regex.")

if is_compiled_re:
# no-op, needed to tell mypy this isn't a string
Expand All @@ -204,7 +207,8 @@ def _re_compile(
flags |= re.IGNORECASE

pat = self._stringify(pat)
return re.compile(pat, flags=flags)
func = lambda x: re.compile(x, flags=flags)
return self._apply(func, obj=pat, dtype=np.object_)

def len(self) -> Any:
"""
Expand Down Expand Up @@ -233,7 +237,7 @@ def __add__(

def __mul__(
self,
num: int,
num: Union[int, Any],
) -> Any:
return self.repeat(num)

Expand Down Expand Up @@ -351,7 +355,7 @@ def f(x, istart, istop, irepl):
def cat(
self,
*others,
sep: Any = "",
sep: Union[str, bytes, Any] = "",
) -> Any:
"""
Concatenate strings elementwise in the DataArray with other strings.
Expand Down Expand Up @@ -436,7 +440,7 @@ def cat(
def join(
self,
dim: Hashable = None,
sep: Any = "",
sep: Union[str, bytes, Any] = "",
) -> Any:
"""
Concatenate strings in a DataArray along a particular dimension.
Expand Down Expand Up @@ -777,7 +781,7 @@ def isupper(self) -> Any:

def count(
self,
pat: Union[str, bytes, Pattern],
pat: Union[str, bytes, Pattern, Any],
flags: int = 0,
case: bool = True,
) -> Any:
Expand Down Expand Up @@ -1558,7 +1562,7 @@ def extract(

if dim is None and pat.groups != 1:
raise ValueError(
"dim must be specified if more than one capture group is given."
"Dimension must be specified if more than one capture group is given."
)

if dim is not None and dim in self._obj.dims:
Expand Down
Loading

0 comments on commit 7d7dd6f

Please sign in to comment.