FIX-modin-project#7039: Pass scalar dtype as is to astype query compiler (modin-project#7152)

Co-authored-by: Iaroslav Igoshev <Poolliver868@mail.ru>
Co-authored-by: Anatoly Myachev <anatoliimyachev@mail.com>
Signed-off-by: arunjose696 <arunjose696@gmail.com>
3 people authored Apr 11, 2024
1 parent 0755a61 commit ad057fa
Showing 5 changed files with 70 additions and 40 deletions.
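
For orientation (not part of the diff): before this change, modin.pandas expanded a scalar dtype into a per-column dict before handing it to the query compiler; now the scalar is passed through unchanged and each layer decides how to handle it. A minimal usage sketch, assuming modin is installed with a supported engine:

import modin.pandas as pd

df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

# Scalar dtype: forwarded to the query compiler as-is and applied to every column.
all_float = df.astype("float64")

# Per-column mapping: still goes through the original dict code path.
partial = df.astype({"a": "float64"})

print(all_float.dtypes)
print(partial.dtypes)
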
81 changes: 54 additions & 27 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
@@ -1625,7 +1625,7 @@ def astype(self, col_dtypes, errors: str = "raise"):

         Parameters
         ----------
-        col_dtypes : dictionary of {col: dtype,...}
+        col_dtypes : dictionary of {col: dtype,...} or str
             Where col is the column name and dtype is a NumPy dtype.
         errors : {'raise', 'ignore'}, default: 'raise'
             Control raising of exceptions on invalid data for provided dtype.
@@ -1642,39 +1642,66 @@ def astype(self, col_dtypes, errors: str = "raise"):
         # will store the encoded table. That can lead to higher memory footprint.
         # TODO: Revisit if this hurts users.
         use_full_axis_cast = False
-        for column, dtype in col_dtypes.items():
-            if not is_dtype_equal(dtype, self_dtypes[column]):
-                if new_dtypes is None:
-                    new_dtypes = self_dtypes.copy()
-                # Update the new dtype series to the proper pandas dtype
-                new_dtype = pandas.api.types.pandas_dtype(dtype)
-                if Engine.get() == "Dask" and hasattr(dtype, "_is_materialized"):
-                    # FIXME: https://github.com/dask/distributed/issues/8585
-                    _ = dtype._materialize_categories()
+        if isinstance(col_dtypes, dict):
+            for column, dtype in col_dtypes.items():
+                if not is_dtype_equal(dtype, self_dtypes[column]):
+                    if new_dtypes is None:
+                        new_dtypes = self_dtypes.copy()
+                    # Update the new dtype series to the proper pandas dtype
+                    new_dtype = pandas.api.types.pandas_dtype(dtype)
+                    if Engine.get() == "Dask" and hasattr(dtype, "_is_materialized"):
+                        # FIXME: https://github.com/dask/distributed/issues/8585
+                        _ = dtype._materialize_categories()
+
+                    # We cannot infer without computing the dtype if new dtype is categorical
+                    if isinstance(new_dtype, pandas.CategoricalDtype):
+                        new_dtypes[column] = LazyProxyCategoricalDtype._build_proxy(
+                            # Actual parent will substitute `None` at `.set_dtypes_cache`
+                            parent=None,
+                            column_name=column,
+                            materializer=lambda parent, column: parent._compute_dtypes(
+                                columns=[column]
+                            )[column],
+                        )
+                        use_full_axis_cast = True
+                    else:
+                        new_dtypes[column] = new_dtype

-                # We cannot infer without computing the dtype if
+            def astype_builder(df):
+                """Compute new partition frame with dtypes updated."""
+                return df.astype(
+                    {k: v for k, v in col_dtypes.items() if k in df}, errors=errors
+                )
+
+        else:
+            # Assume that the dtype is a scalar.
+            if not (col_dtypes == self_dtypes).all():
+                new_dtypes = self_dtypes.copy()
+                new_dtype = pandas.api.types.pandas_dtype(col_dtypes)
+                if Engine.get() == "Dask" and hasattr(new_dtype, "_is_materialized"):
+                    # FIXME: https://github.com/dask/distributed/issues/8585
+                    _ = new_dtype._materialize_categories()
                 if isinstance(new_dtype, pandas.CategoricalDtype):
-                    new_dtypes[column] = LazyProxyCategoricalDtype._build_proxy(
-                        # Actual parent will substitute `None` at `.set_dtypes_cache`
-                        parent=None,
-                        column_name=column,
-                        materializer=lambda parent, column: parent._compute_dtypes(
-                            columns=[column]
-                        )[column],
-                    )
+                    new_dtypes[:] = new_dtypes.to_frame().apply(
+                        lambda column: LazyProxyCategoricalDtype._build_proxy(
+                            # Actual parent will substitute `None` at `.set_dtypes_cache`
+                            parent=None,
+                            column_name=column.index[0],
+                            materializer=lambda parent, column: parent._compute_dtypes(
+                                columns=[column]
+                            )[column],
+                        )
+                    )[0]
                     use_full_axis_cast = True
                 else:
-                    new_dtypes[column] = new_dtype
+                    new_dtypes[:] = new_dtype

+            def astype_builder(df):
+                """Compute new partition frame with dtypes updated."""
+                return df.astype(col_dtypes, errors=errors)
+
         if new_dtypes is None:
             return self.copy()

-        def astype_builder(df):
-            """Compute new partition frame with dtypes updated."""
-            return df.astype(
-                {k: v for k, v in col_dtypes.items() if k in df}, errors=errors
-            )
-
         if use_full_axis_cast:
             new_frame = self._partition_mgr_cls.map_axis_partitions(
                 0, self._partitions, astype_builder, keep_partitioning=True
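
A condensed sketch of the dtype bookkeeping introduced above (a hypothetical helper using plain pandas; it leaves out the lazy categorical proxies and the Dask materialization workaround):

import pandas
from pandas.api.types import is_dtype_equal, pandas_dtype


def compute_new_dtypes(self_dtypes, col_dtypes):
    """Return an updated dtypes Series, or None when no column would change."""
    if isinstance(col_dtypes, dict):
        new_dtypes = None
        for column, dtype in col_dtypes.items():
            if not is_dtype_equal(dtype, self_dtypes[column]):
                if new_dtypes is None:
                    new_dtypes = self_dtypes.copy()
                new_dtypes[column] = pandas_dtype(dtype)
        return new_dtypes
    # Scalar dtype: compare it against every column at once.
    if (col_dtypes == self_dtypes).all():
        return None
    new_dtypes = self_dtypes.copy()
    new_dtypes[:] = pandas_dtype(col_dtypes)
    return new_dtypes


dtypes = pandas.DataFrame({"a": [1], "b": [2.0]}).dtypes
print(compute_new_dtypes(dtypes, "float64"))         # "a" becomes float64
print(compute_new_dtypes(dtypes, {"b": "float64"}))  # None: "b" already matches
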
2 changes: 1 addition & 1 deletion modin/core/storage_formats/base/query_compiler.py
@@ -1860,7 +1860,7 @@ def astype(self, col_dtypes, errors: str = "raise"): # noqa: PR02

         Parameters
         ----------
-        col_dtypes : dict
+        col_dtypes : dict or str
             Map for column names and new dtypes.
         errors : {'raise', 'ignore'}, default: 'raise'
             Control raising of exceptions on invalid data for provided dtype.
@@ -1041,7 +1041,7 @@ def astype(self, col_dtypes, **kwargs):

         Parameters
         ----------
-        col_dtypes : dict
+        col_dtypes : dict or str
             Maps column names to new data types.
         **kwargs : dict
             Keyword args. Not used.
@@ -1051,6 +1051,8 @@ def astype(self, col_dtypes, **kwargs):
         HdkOnNativeDataframe
             The new frame.
         """
+        if not isinstance(col_dtypes, dict):
+            col_dtypes = {column: col_dtypes for column in self.columns}
         columns = col_dtypes.keys()
         new_dtypes = self.copy_dtypes_cache()
         for column in columns:
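
The HDK frame keeps its dict-based loop and simply widens a scalar argument into a per-column mapping up front. A standalone sketch of that normalization, with a hypothetical function name:

def normalize_col_dtypes(col_dtypes, columns):
    """Expand a scalar dtype into a {column: dtype} mapping; dicts pass through."""
    if not isinstance(col_dtypes, dict):
        col_dtypes = {column: col_dtypes for column in columns}
    return col_dtypes


print(normalize_col_dtypes("float64", ["a", "b"]))       # {'a': 'float64', 'b': 'float64'}
print(normalize_col_dtypes({"a": "int32"}, ["a", "b"]))  # returned unchanged
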
18 changes: 9 additions & 9 deletions modin/pandas/base.py
@@ -1005,7 +1005,7 @@ def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200
         """
         if copy is None:
             copy = True
-        # dtype can be a series, a dict, or a scalar. If it's series or scalar,
+        # dtype can be a series, a dict, or a scalar. If it's series,
         # convert it to a dict before passing it to the query compiler.
         if isinstance(dtype, (pd.Series, pandas.Series)):
             if not dtype.index.is_unique:
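
The comment above now mentions only the Series case because a scalar is forwarded untouched. A small sketch of the remaining conversion (hypothetical helper; the real method also validates dict keys against the frame's columns):

import pandas


def prepare_dtype_arg(dtype):
    """A per-column Series of dtypes becomes a dict; dicts and scalars pass through."""
    if isinstance(dtype, pandas.Series):
        if not dtype.index.is_unique:
            # Illustrative message only; the real check raises its own error.
            raise ValueError("dtype Series must have a unique index")
        return dtype.to_dict()
    return dtype


per_column = pandas.Series({"a": "float64", "b": "int64"})
print(prepare_dtype_arg(per_column))  # {'a': 'float64', 'b': 'int64'}
print(prepare_dtype_arg("float64"))   # scalar passed through unchanged
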
@@ -1026,24 +1026,24 @@ def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200
                     "Only a column name can be used for the key in "
                     + "a dtype mappings argument."
                 )
-            col_dtypes = dtype
-        else:
-            # Assume that the dtype is a scalar.
-            col_dtypes = {column: dtype for column in self._query_compiler.columns}

         if not copy:
             # If the new types match the old ones, then copying can be avoided
             if self._query_compiler._modin_frame.has_materialized_dtypes:
                 frame_dtypes = self._query_compiler._modin_frame.dtypes
-                for col in col_dtypes:
-                    if col_dtypes[col] != frame_dtypes[col]:
+                if isinstance(dtype, dict):
+                    for col in dtype:
+                        if dtype[col] != frame_dtypes[col]:
+                            copy = True
+                            break
+                else:
+                    if not (frame_dtypes == dtype).all():
                         copy = True
-                        break
             else:
                 copy = True

         if copy:
-            new_query_compiler = self._query_compiler.astype(col_dtypes, errors=errors)
+            new_query_compiler = self._query_compiler.astype(dtype, errors=errors)
             return self._create_or_update_from_compiler(new_query_compiler)
         return self

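A condensed sketch of the copy-avoidance check above, as a hypothetical free function; in modin the dtypes come from the query compiler's materialized cache:

import pandas


def needs_copy(frame_dtypes, dtype):
    """Return True when casting would actually change at least one column."""
    if isinstance(dtype, dict):
        return any(dtype[col] != frame_dtypes[col] for col in dtype)
    # Scalar dtype: skip the copy only if every column already matches.
    return not (frame_dtypes == dtype).all()


dtypes = pandas.DataFrame({"a": [1], "b": [2.0]}).dtypes
print(needs_copy(dtypes, "float64"))         # True: "a" is still int64
print(needs_copy(dtypes, {"b": "float64"}))  # False: "b" already matches
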
5 changes: 3 additions & 2 deletions modin/tests/pandas/test_series.py
@@ -1079,8 +1079,9 @@ def test_astype(data, request):
     eval_general(modin_series, pandas_series, lambda df: df.astype(str))
     expected_exception = None
     if "float_nan_data" in request.node.callspec.id:
-        # FIXME: https://github.com/modin-project/modin/issues/7039
-        expected_exception = False
+        expected_exception = pd.errors.IntCastingNaNError(
+            "Cannot convert non-finite values (NA or inf) to integer"
+        )
     eval_general(
         modin_series,
         pandas_series,
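
The new expected exception mirrors what pandas itself raises when a float series containing NaN is cast to an integer dtype. A quick reproduction with plain pandas (recent 2.x versions):

import numpy as np
import pandas

s = pandas.Series([1.0, np.nan, 3.0])
try:
    s.astype("int64")
except pandas.errors.IntCastingNaNError as err:
    # Prints: Cannot convert non-finite values (NA or inf) to integer
    print(err)
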