From c7278871e9326e68e015a01f596913023cb22de8 Mon Sep 17 00:00:00 2001 From: helmeleegy <40042062+helmeleegy@users.noreply.github.com> Date: Thu, 2 Feb 2023 11:09:12 -0800 Subject: [PATCH] Fix str_cat/extract/partition/replace/rpartition (#51) (core) --- modin/pandas/series_utils.py | 48 ++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/modin/pandas/series_utils.py b/modin/pandas/series_utils.py index 4399c78b516..f4c866adc37 100644 --- a/modin/pandas/series_utils.py +++ b/modin/pandas/series_utils.py @@ -141,8 +141,11 @@ def casefold(self): def cat(self, others=None, sep=None, na_rep=None, join=None): if isinstance(others, Series): others = others._to_pandas() - return self._default_to_pandas( - pandas.Series.str.cat, others=others, sep=sep, na_rep=na_rep, join=join + data = Series(query_compiler=self._query_compiler) + return data._reduce_dimension( + self._query_compiler.str_cat( + others=others, sep=sep, na_rep=na_rep, join=join + ) ) def decode(self, encoding, errors="strict"): @@ -307,7 +310,10 @@ def match(self, pat, case=True, flags=0, na=np.NaN): ) def extract(self, pat, flags=0, expand=True): - if expand: + import re + + n = re.compile(pat).groups + if expand or n > 1: from .dataframe import DataFrame return DataFrame( @@ -337,9 +343,21 @@ def lstrip(self, to_strip=None): def partition(self, sep=" ", expand=True): if sep is not None and len(sep) == 0: raise ValueError("empty separator") - return Series( - query_compiler=self._query_compiler.str_partition(sep=sep, expand=expand) - ) + + if expand: + from .dataframe import DataFrame + + return DataFrame( + query_compiler=self._query_compiler.str_partition( + sep=sep, expand=expand + ) + ) + else: + return Series( + query_compiler=self._query_compiler.str_partition( + sep=sep, expand=expand + ) + ) def removeprefix(self, prefix): return Series(query_compiler=self._query_compiler.str_removeprefix(prefix)) @@ -353,11 +371,21 @@ def repeat(self, repeats): def rpartition(self, sep=" ", expand=True): if sep is not None and len(sep) == 0: raise ValueError("empty separator") - return Series( - query_compiler=self._query_compiler.str_rpartition( - sep=sep, expand=expand + + if expand: + from .dataframe import DataFrame + + return DataFrame( + query_compiler=self._query_compiler.str_rpartition( + sep=sep, expand=expand + ) + ) + else: + return Series( + query_compiler=self._query_compiler.str_rpartition( + sep=sep, expand=expand + ) ) - ) def lower(self): return Series(query_compiler=self._query_compiler.str_lower())