From 9a65e777fce940a2de3e9ace6f5119023d44a011 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev <62142979+dchigarev@users.noreply.github.com> Date: Thu, 30 Jul 2020 01:05:55 +0300 Subject: [PATCH] FIX-#1154: properly process UDFs (#1845) Signed-off-by: Dmitry Chigarev --- modin/backends/pandas/query_compiler.py | 21 ++++++----- modin/engines/base/frame/data.py | 48 ++++++++++++++++++++----- modin/pandas/base.py | 8 +++-- modin/pandas/groupby.py | 3 +- modin/pandas/test/test_dataframe.py | 12 +++++++ modin/pandas/test/utils.py | 14 ++++++++ modin/pandas/utils.py | 11 ++++++ 7 files changed, 96 insertions(+), 21 deletions(-) diff --git a/modin/backends/pandas/query_compiler.py b/modin/backends/pandas/query_compiler.py index 4d585658893..ef9c0314230 100644 --- a/modin/backends/pandas/query_compiler.py +++ b/modin/backends/pandas/query_compiler.py @@ -22,6 +22,7 @@ from modin.backends.base.query_compiler import BaseQueryCompiler from modin.error_message import ErrorMessage +from modin.pandas.utils import try_cast_to_pandas, wrap_udf_function from modin.data_management.functions import ( FoldFunction, MapFunction, @@ -1896,6 +1897,10 @@ def apply(self, func, axis, *args, **kwargs): Returns: A new PandasQueryCompiler. """ + # if any of args contain modin object, we should + # convert it to pandas + args = try_cast_to_pandas(args) + kwargs = try_cast_to_pandas(kwargs) if isinstance(func, str): return self._apply_text_func_elementwise(func, axis, *args, **kwargs) elif callable(func): @@ -1920,7 +1925,7 @@ def _apply_text_func_elementwise(self, func, axis, *args, **kwargs): assert isinstance(func, str) kwargs["axis"] = axis new_modin_frame = self._modin_frame._apply_full_axis( - axis, lambda df: getattr(df, func)(**kwargs) + axis, lambda df: df.apply(func, *args, **kwargs) ) return self.__constructor__(new_modin_frame) @@ -1942,6 +1947,7 @@ def dict_apply_builder(df, func_dict={}): # all objects are `DataFrame`s. 
return pandas.DataFrame(df.apply(func_dict, *args, **kwargs)) + func = {k: wrap_udf_function(v) if callable(v) else v for k, v in func.items()} return self.__constructor__( self._modin_frame._apply_full_axis_select_indices( axis, dict_apply_builder, func, keep_remaining=False @@ -1969,6 +1975,7 @@ def _list_like_func(self, func, axis, *args, **kwargs): if axis == 1 else self.columns ) + func = [wrap_udf_function(f) if callable(f) else f for f in func] new_modin_frame = self._modin_frame._apply_full_axis( axis, lambda df: pandas.DataFrame(df.apply(func, axis, *args, **kwargs)), @@ -1987,14 +1994,10 @@ def _callable_func(self, func, axis, *args, **kwargs): Returns: A new PandasQueryCompiler. """ - if isinstance(pandas.DataFrame().apply(func), pandas.Series): - new_modin_frame = self._modin_frame._fold_reduce( - axis, lambda df: df.apply(func, axis=axis, *args, **kwargs) - ) - else: - new_modin_frame = self._modin_frame._apply_full_axis( - axis, lambda df: df.apply(func, axis=axis, *args, **kwargs) - ) + func = wrap_udf_function(func) + new_modin_frame = self._modin_frame._apply_full_axis( + axis, lambda df: df.apply(func, axis=axis, *args, **kwargs) + ) return self.__constructor__(new_modin_frame) # END UDF diff --git a/modin/engines/base/frame/data.py b/modin/engines/base/frame/data.py index 462b50d1ce7..d59fbba5dc8 100644 --- a/modin/engines/base/frame/data.py +++ b/modin/engines/base/frame/data.py @@ -211,7 +211,7 @@ def _set_columns(self, new_columns): self._dtypes.index = new_columns self._apply_index_objs(axis=1) - def _set_axis(self, axis, new_axis): + def _set_axis(self, axis, new_axis, cache_only=False): """Replaces the current labels at the specified axis with the new one Parameters @@ -220,11 +220,20 @@ def _set_axis(self, axis, new_axis): Axis to set labels along new_axis : Index, The replacement labels + cache_only : bool, + Whether to change only external indices, or propagate it + into partitions """ if axis: - self._set_columns(new_axis) + if not 
cache_only: + self._set_columns(new_axis) + else: + self._columns_cache = ensure_index(new_axis) else: - self._set_index(new_axis) + if not cache_only: + self._set_index(new_axis) + else: + self._index_cache = ensure_index(new_axis) columns = property(_get_columns, _set_columns) index = property(_get_index, _set_index) @@ -256,7 +265,7 @@ def _filter_empties(self): self._column_widths_cache = [w for w in self._column_widths if w > 0] self._row_lengths_cache = [r for r in self._row_lengths if r > 0] - def _validate_axis_equality(self, axis: int): + def _validate_axis_equality(self, axis: int, force: bool = False): """ Validates internal and external indices of modin_frame at the specified axis. @@ -264,22 +273,32 @@ def _validate_axis_equality(self, axis: int): ---------- axis : int, Axis to validate indices along + force : bool, + Whether to update external indices with internal if their lengths + do not match or raise an exception in that case. """ internal_axis = self._frame_mgr_cls.get_indices( axis, self._partitions, lambda df: df.axes[axis] ) is_equals = self.axes[axis].equals(internal_axis) + is_lenghts_matches = len(self.axes[axis]) == len(internal_axis) if not is_equals: - self._set_axis(axis, self.axes[axis]) + if force: + new_axis = self.axes[axis] if is_lenghts_matches else internal_axis + self._set_axis(axis, new_axis, cache_only=not is_lenghts_matches) + else: + self._set_axis( + axis, self.axes[axis], + ) def _validate_internal_indices(self, mode=None, **kwargs): """ Validates and optionally updates internal and external indices of modin_frame in specified mode. There is 3 modes supported: - 1. "reduced" - validates and updates indices on that axes + 1. "reduced" - force validates on that axes where external indices is ["__reduced__"] - 2. "all" - validates indices at all axes, optionally updates - internal indices if `update` parameter specified in kwargs + 2. 
"all" - validates indices at all axes, optionally force + if `force` parameter specified in kwargs 3. "custom" - validation follows arguments specified in kwargs. Parameters @@ -287,10 +306,16 @@ def _validate_internal_indices(self, mode=None, **kwargs): mode : str or bool, default None validate_index : bool, (optional, could be specified via `mode`) validate_columns : bool, (optional, could be specified via `mode`) + force : bool (optional, could be specified via `mode`) + Whether to update external indices with internal if their lengths + do not match or raise an exception in that case. """ if isinstance(mode, bool): + is_force = mode mode = "all" + else: + is_force = kwargs.get("force", False) reduced_sample = pandas.Index(["__reduced__"]) args_dict = { @@ -298,8 +323,13 @@ def _validate_internal_indices(self, mode=None, **kwargs): "reduced": { "validate_index": self.index.equals(reduced_sample), "validate_columns": self.columns.equals(reduced_sample), + "force": True, + }, + "all": { + "validate_index": True, + "validate_columns": True, + "force": is_force, }, - "all": {"validate_index": True, "validate_columns": True}, } args = args_dict.get(mode, args_dict["custom"]) diff --git a/modin/pandas/base.py b/modin/pandas/base.py index 10d9b01e050..03dac9e0565 100644 --- a/modin/pandas/base.py +++ b/modin/pandas/base.py @@ -583,7 +583,9 @@ def apply( axis = self._get_axis_number(axis) ErrorMessage.non_verified_udf() if isinstance(func, str): - result = self._query_compiler.apply(func, axis=axis, *args, **kwds) + result = self._query_compiler.apply( + func, axis=axis, raw=raw, result_type=result_type, *args, **kwds, + ) if isinstance(result, BasePandasDataset): return result._query_compiler return result @@ -601,7 +603,9 @@ def apply( ) elif not callable(func) and not is_list_like(func): raise TypeError("{} object is not callable".format(type(func))) - query_compiler = self._query_compiler.apply(func, axis, args=args, **kwds) + query_compiler = 
self._query_compiler.apply( + func, axis, args=args, raw=raw, result_type=result_type, **kwds, + ) return query_compiler def asfreq(self, freq, method=None, how=None, normalize=False, fill_value=None): diff --git a/modin/pandas/groupby.py b/modin/pandas/groupby.py index 383f23ab99b..c9f35f7a94f 100644 --- a/modin/pandas/groupby.py +++ b/modin/pandas/groupby.py @@ -19,7 +19,7 @@ from modin.error_message import ErrorMessage -from .utils import _inherit_docstrings +from .utils import _inherit_docstrings, wrap_udf_function from .series import Series @@ -644,6 +644,7 @@ def _apply_agg_function(self, f, drop=True, *args, **kwargs): """ assert callable(f), "'{0}' object is not callable".format(type(f)) + f = wrap_udf_function(f) if self._is_multi_by: return self._default_to_pandas(f, *args, **kwargs) diff --git a/modin/pandas/test/test_dataframe.py b/modin/pandas/test/test_dataframe.py index 43786a5e2c9..487e791c63a 100644 --- a/modin/pandas/test/test_dataframe.py +++ b/modin/pandas/test/test_dataframe.py @@ -58,6 +58,8 @@ create_test_dfs, test_data_small_values, test_data_small_keys, + udf_func_values, + udf_func_keys, ) pd.DEFAULT_NPARTITIONS = 4 @@ -1812,6 +1814,16 @@ def test_apply_numeric(self, request, data, func, axis): pandas_result = pandas_df.apply(lambda df: df.drop(key), axis=1) df_equals(modin_result, pandas_result) + @pytest.mark.parametrize("func", udf_func_values, ids=udf_func_keys) + @pytest.mark.parametrize("data", test_data_values, ids=test_data_keys) + def test_apply_udf(self, data, func): + eval_general( + *create_test_dfs(data), + lambda df, *args, **kwargs: df.apply(*args, **kwargs), + func=func, + other=lambda df: df, + ) + def test_eval_df_use_case(self): frame_data = {"a": random_state.randn(10), "b": random_state.randn(10)} df = pandas.DataFrame(frame_data) diff --git a/modin/pandas/test/utils.py b/modin/pandas/test/utils.py index 27415c3d61a..ebf0ce97828 100644 --- a/modin/pandas/test/utils.py +++ b/modin/pandas/test/utils.py @@ -259,6 +259,20 
@@ numeric_agg_funcs = ["sum mean", "sum sum", "sum df sum"] +udf_func = { + "return self": lambda df: lambda x, *args, **kwargs: type(x)(x.values), + "change index": lambda df: lambda x, *args, **kwargs: pandas.Series( + x.values, index=np.arange(-1, len(x.index) - 1) + ), + "return none": lambda df: lambda x, *args, **kwargs: None, + "return empty": lambda df: lambda x, *args, **kwargs: pandas.Series(), + "access self": lambda df: lambda x, other, *args, **kwargs: pandas.Series( + x.values, index=other.index + ), +} +udf_func_keys = list(udf_func.keys()) +udf_func_values = list(udf_func.values()) + # Test q values for quantiles quantiles = { "0.25": 0.25, diff --git a/modin/pandas/utils.py b/modin/pandas/utils.py index 8ed2e60d5cd..0696ef325c9 100644 --- a/modin/pandas/utils.py +++ b/modin/pandas/utils.py @@ -116,3 +116,14 @@ def try_cast_to_pandas(obj): else getattr(pandas.Series, fn_name, obj) ) return obj + + +def wrap_udf_function(func): + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + # if the user accidentally returns a modin DataFrame or Series, + # cast it back to pandas so it is processed properly + return try_cast_to_pandas(result) + + wrapper.__name__ = func.__name__ + return wrapper