Skip to content

Commit

Permalink
FIX-modin-project#1154: properly process UDFs (modin-project#1845)
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
  • Loading branch information
dchigarev authored and aregm committed Sep 16, 2020
1 parent df3f1d4 commit 0e826ac
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 21 deletions.
21 changes: 12 additions & 9 deletions modin/backends/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from modin.backends.base.query_compiler import BaseQueryCompiler
from modin.error_message import ErrorMessage
from modin.pandas.utils import try_cast_to_pandas, wrap_udf_function
from modin.data_management.functions import (
FoldFunction,
MapFunction,
Expand Down Expand Up @@ -1896,6 +1897,10 @@ def apply(self, func, axis, *args, **kwargs):
Returns:
A new PandasQueryCompiler.
"""
# if any of args contain modin object, we should
# convert it to pandas
args = try_cast_to_pandas(args)
kwargs = try_cast_to_pandas(kwargs)
if isinstance(func, str):
return self._apply_text_func_elementwise(func, axis, *args, **kwargs)
elif callable(func):
Expand All @@ -1920,7 +1925,7 @@ def _apply_text_func_elementwise(self, func, axis, *args, **kwargs):
assert isinstance(func, str)
kwargs["axis"] = axis
new_modin_frame = self._modin_frame._apply_full_axis(
axis, lambda df: getattr(df, func)(**kwargs)
axis, lambda df: df.apply(func, *args, **kwargs)
)
return self.__constructor__(new_modin_frame)

Expand All @@ -1942,6 +1947,7 @@ def dict_apply_builder(df, func_dict={}):
# all objects are `DataFrame`s.
return pandas.DataFrame(df.apply(func_dict, *args, **kwargs))

func = {k: wrap_udf_function(v) if callable(v) else v for k, v in func.items()}
return self.__constructor__(
self._modin_frame._apply_full_axis_select_indices(
axis, dict_apply_builder, func, keep_remaining=False
Expand Down Expand Up @@ -1969,6 +1975,7 @@ def _list_like_func(self, func, axis, *args, **kwargs):
if axis == 1
else self.columns
)
func = [wrap_udf_function(f) if callable(f) else f for f in func]
new_modin_frame = self._modin_frame._apply_full_axis(
axis,
lambda df: pandas.DataFrame(df.apply(func, axis, *args, **kwargs)),
Expand All @@ -1987,14 +1994,10 @@ def _callable_func(self, func, axis, *args, **kwargs):
Returns:
A new PandasQueryCompiler.
"""
if isinstance(pandas.DataFrame().apply(func), pandas.Series):
new_modin_frame = self._modin_frame._fold_reduce(
axis, lambda df: df.apply(func, axis=axis, *args, **kwargs)
)
else:
new_modin_frame = self._modin_frame._apply_full_axis(
axis, lambda df: df.apply(func, axis=axis, *args, **kwargs)
)
func = wrap_udf_function(func)
new_modin_frame = self._modin_frame._apply_full_axis(
axis, lambda df: df.apply(func, axis=axis, *args, **kwargs)
)
return self.__constructor__(new_modin_frame)

# END UDF
Expand Down
48 changes: 39 additions & 9 deletions modin/engines/base/frame/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _set_columns(self, new_columns):
self._dtypes.index = new_columns
self._apply_index_objs(axis=1)

def _set_axis(self, axis, new_axis):
def _set_axis(self, axis, new_axis, cache_only=False):
"""Replaces the current labels at the specified axis with the new one
Parameters
Expand All @@ -220,11 +220,20 @@ def _set_axis(self, axis, new_axis):
Axis to set labels along
new_axis : Index,
The replacement labels
cache_only : bool,
Whether to change only external indices, or propagate it
into partitions
"""
if axis:
self._set_columns(new_axis)
if not cache_only:
self._set_columns(new_axis)
else:
self._columns_cache = ensure_index(new_axis)
else:
self._set_index(new_axis)
if not cache_only:
self._set_index(new_axis)
else:
self._index_cache = ensure_index(new_axis)

columns = property(_get_columns, _set_columns)
index = property(_get_index, _set_index)
Expand Down Expand Up @@ -256,50 +265,71 @@ def _filter_empties(self):
self._column_widths_cache = [w for w in self._column_widths if w > 0]
self._row_lengths_cache = [r for r in self._row_lengths if r > 0]

def _validate_axis_equality(self, axis: int):
def _validate_axis_equality(self, axis: int, force: bool = False):
"""
Validates internal and external indices of modin_frame at the specified axis.
Parameters
----------
axis : int,
Axis to validate indices along
force : bool,
Whether to update external indices with internal if their lengths
do not match or raise an exception in that case.
"""
internal_axis = self._frame_mgr_cls.get_indices(
axis, self._partitions, lambda df: df.axes[axis]
)
is_equals = self.axes[axis].equals(internal_axis)
is_lenghts_matches = len(self.axes[axis]) == len(internal_axis)
if not is_equals:
self._set_axis(axis, self.axes[axis])
if force:
new_axis = self.axes[axis] if is_lenghts_matches else internal_axis
self._set_axis(axis, new_axis, cache_only=not is_lenghts_matches)
else:
self._set_axis(
axis, self.axes[axis],
)

def _validate_internal_indices(self, mode=None, **kwargs):
"""
Validates and optionally updates internal and external indices
of modin_frame in specified mode. There is 3 modes supported:
1. "reduced" - validates and updates indices on that axes
1. "reduced" - force validates on that axes
where external indices is ["__reduced__"]
2. "all" - validates indices at all axes, optionally updates
internal indices if `update` parameter specified in kwargs
2. "all" - validates indices at all axes, optionally force
if `force` parameter specified in kwargs
3. "custom" - validation follows arguments specified in kwargs.
Parameters
----------
mode : str or bool, default None
validate_index : bool, (optional, could be specified via `mode`)
validate_columns : bool, (optional, could be specified via `mode`)
force : bool (optional, could be specified via `mode`)
Whether to update external indices with internal if their lengths
do not match or raise an exception in that case.
"""

if isinstance(mode, bool):
is_force = mode
mode = "all"
else:
is_force = kwargs.get("force", False)

reduced_sample = pandas.Index(["__reduced__"])
args_dict = {
"custom": kwargs,
"reduced": {
"validate_index": self.index.equals(reduced_sample),
"validate_columns": self.columns.equals(reduced_sample),
"force": True,
},
"all": {
"validate_index": True,
"validate_columns": True,
"force": is_force,
},
"all": {"validate_index": True, "validate_columns": True},
}

args = args_dict.get(mode, args_dict["custom"])
Expand Down
8 changes: 6 additions & 2 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,9 @@ def apply(
axis = self._get_axis_number(axis)
ErrorMessage.non_verified_udf()
if isinstance(func, str):
result = self._query_compiler.apply(func, axis=axis, *args, **kwds)
result = self._query_compiler.apply(
func, axis=axis, raw=raw, result_type=result_type, *args, **kwds,
)
if isinstance(result, BasePandasDataset):
return result._query_compiler
return result
Expand All @@ -601,7 +603,9 @@ def apply(
)
elif not callable(func) and not is_list_like(func):
raise TypeError("{} object is not callable".format(type(func)))
query_compiler = self._query_compiler.apply(func, axis, args=args, **kwds)
query_compiler = self._query_compiler.apply(
func, axis, args=args, raw=raw, result_type=result_type, **kwds,
)
return query_compiler

def asfreq(self, freq, method=None, how=None, normalize=False, fill_value=None):
Expand Down
3 changes: 2 additions & 1 deletion modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from modin.error_message import ErrorMessage

from .utils import _inherit_docstrings
from .utils import _inherit_docstrings, wrap_udf_function
from .series import Series


Expand Down Expand Up @@ -644,6 +644,7 @@ def _apply_agg_function(self, f, drop=True, *args, **kwargs):
"""
assert callable(f), "'{0}' object is not callable".format(type(f))

f = wrap_udf_function(f)
if self._is_multi_by:
return self._default_to_pandas(f, *args, **kwargs)

Expand Down
12 changes: 12 additions & 0 deletions modin/pandas/test/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
create_test_dfs,
test_data_small_values,
test_data_small_keys,
udf_func_values,
udf_func_keys,
)

pd.DEFAULT_NPARTITIONS = 4
Expand Down Expand Up @@ -1812,6 +1814,16 @@ def test_apply_numeric(self, request, data, func, axis):
pandas_result = pandas_df.apply(lambda df: df.drop(key), axis=1)
df_equals(modin_result, pandas_result)

@pytest.mark.parametrize("func", udf_func_values, ids=udf_func_keys)
@pytest.mark.parametrize("data", test_data_values, ids=test_data_keys)
def test_apply_udf(self, data, func):
eval_general(
*create_test_dfs(data),
lambda df, *args, **kwargs: df.apply(*args, **kwargs),
func=func,
other=lambda df: df,
)

def test_eval_df_use_case(self):
frame_data = {"a": random_state.randn(10), "b": random_state.randn(10)}
df = pandas.DataFrame(frame_data)
Expand Down
14 changes: 14 additions & 0 deletions modin/pandas/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,20 @@

numeric_agg_funcs = ["sum mean", "sum sum", "sum df sum"]

udf_func = {
"return self": lambda df: lambda x, *args, **kwargs: type(x)(x.values),
"change index": lambda df: lambda x, *args, **kwargs: pandas.Series(
x.values, index=np.arange(-1, len(x.index) - 1)
),
"return none": lambda df: lambda x, *args, **kwargs: None,
"return empty": lambda df: lambda x, *args, **kwargs: pandas.Series(),
"access self": lambda df: lambda x, other, *args, **kwargs: pandas.Series(
x.values, index=other.index
),
}
udf_func_keys = list(udf_func.keys())
udf_func_values = list(udf_func.values())

# Test q values for quantiles
quantiles = {
"0.25": 0.25,
Expand Down
11 changes: 11 additions & 0 deletions modin/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,14 @@ def try_cast_to_pandas(obj):
else getattr(pandas.Series, fn_name, obj)
)
return obj


def wrap_udf_function(func):
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
# if user accidently returns modin DataFrame or Series
# casting it back to pandas to properly process
return try_cast_to_pandas(result)

wrapper.__name__ = func.__name__
return wrapper

0 comments on commit 0e826ac

Please sign in to comment.