Skip to content

Commit

Permalink
FEAT-modin-project#7039: pass scalar dtype as is to astype query comp…
Browse files Browse the repository at this point in the history
…iler

Signed-off-by: arunjose696 <arunjose696@gmail.com>
  • Loading branch information
arunjose696 committed Apr 5, 2024
1 parent 23c1ec0 commit 98107c2
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 34 deletions.
66 changes: 44 additions & 22 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,7 @@ def astype(self, col_dtypes, errors: str = "raise"):
Parameters
----------
col_dtypes : dictionary of {col: dtype,...}
col_dtypes : dictionary of {col: dtype,...} or str
Where col is the column name and dtype is a NumPy dtype.
errors : {'raise', 'ignore'}, default: 'raise'
Control raising of exceptions on invalid data for provided dtype.
Expand All @@ -1642,39 +1642,61 @@ def astype(self, col_dtypes, errors: str = "raise"):
# will store the encoded table. That can lead to higher memory footprint.
# TODO: Revisit if this hurts users.
use_full_axis_cast = False
for column, dtype in col_dtypes.items():
if not is_dtype_equal(dtype, self_dtypes[column]):
if new_dtypes is None:
new_dtypes = self_dtypes.copy()
# Update the new dtype series to the proper pandas dtype
new_dtype = pandas.api.types.pandas_dtype(dtype)
if Engine.get() == "Dask" and hasattr(dtype, "_is_materialized"):
# FIXME: https://github.com/dask/distributed/issues/8585
_ = dtype._materialize_categories()

# We cannot infer without computing the dtype if
if isinstance(col_dtypes, dict):
for column, dtype in col_dtypes.items():
if not is_dtype_equal(dtype, self_dtypes[column]):
if new_dtypes is None:
new_dtypes = self_dtypes.copy()
# Update the new dtype series to the proper pandas dtype
new_dtype = pandas.api.types.pandas_dtype(dtype)
if Engine.get() == "Dask" and hasattr(dtype, "_is_materialized"):
# FIXME: https://github.com/dask/distributed/issues/8585
_ = dtype._materialize_categories()

# We cannot infer without computing the dtype if
if isinstance(new_dtype, pandas.CategoricalDtype):
new_dtypes[column] = LazyProxyCategoricalDtype._build_proxy(
# Actual parent will substitute `None` at `.set_dtypes_cache`
parent=None,
column_name=column,
materializer=lambda parent, column: parent._compute_dtypes(
columns=[column]
)[column],
)
use_full_axis_cast = True
else:
new_dtypes[column] = new_dtype

def astype_builder(df):
"""Compute new partition frame with dtypes updated."""
return df.astype(
{k: v for k, v in col_dtypes.items() if k in df}, errors=errors
)

else:
# Assume that the dtype is a scalar.
if not np.all(col_dtypes == self_dtypes):
new_dtypes = self_dtypes.copy()
new_dtype = pandas.api.types.pandas_dtype(col_dtypes)
if isinstance(new_dtype, pandas.CategoricalDtype):
new_dtypes[column] = LazyProxyCategoricalDtype._build_proxy(
new_dtypes[:] = LazyProxyCategoricalDtype._build_proxy(
# Actual parent will substitute `None` at `.set_dtypes_cache`
parent=None,
column_name=column,
column_name=new_dtypes.index,
materializer=lambda parent, column: parent._compute_dtypes(
columns=[column]
)[column],
)
use_full_axis_cast = True
else:
new_dtypes[column] = new_dtype
new_dtypes[:] = new_dtype

def astype_builder(df):
"""Compute new partition frame with dtypes updated."""
return df.astype(col_dtypes, errors=errors)

if new_dtypes is None:
return self.copy()

def astype_builder(df):
"""Compute new partition frame with dtypes updated."""
return df.astype(
{k: v for k, v in col_dtypes.items() if k in df}, errors=errors
)

if use_full_axis_cast:
new_frame = self._partition_mgr_cls.map_axis_partitions(
0, self._partitions, astype_builder, keep_partitioning=True
Expand Down
2 changes: 1 addition & 1 deletion modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,7 +1845,7 @@ def astype(self, col_dtypes, errors: str = "raise"): # noqa: PR02
Parameters
----------
col_dtypes : dict
col_dtypes : dict or str
Map for column names and new dtypes.
errors : {'raise', 'ignore'}, default: 'raise'
Control raising of exceptions on invalid data for provided dtype.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ def astype(self, col_dtypes, **kwargs):
Parameters
----------
col_dtypes : dict
col_dtypes : dict or str
Maps column names to new data types.
**kwargs : dict
Keyword args. Not used.
Expand All @@ -1046,6 +1046,8 @@ def astype(self, col_dtypes, **kwargs):
HdkOnNativeDataframe
The new frame.
"""
if not isinstance(col_dtypes, dict):
col_dtypes = {column: col_dtypes for column in self.columns}
columns = col_dtypes.keys()
new_dtypes = self.copy_dtypes_cache()
for column in columns:
Expand Down
18 changes: 9 additions & 9 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200
"""
if copy is None:
copy = True
# dtype can be a series, a dict, or a scalar. If it's series or scalar,
# dtype can be a series, a dict, or a scalar. If it's series.
# convert it to a dict before passing it to the query compiler.
if isinstance(dtype, (pd.Series, pandas.Series)):
if not dtype.index.is_unique:
Expand All @@ -1022,24 +1022,24 @@ def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200
"Only a column name can be used for the key in "
+ "a dtype mappings argument."
)
col_dtypes = dtype
else:
# Assume that the dtype is a scalar.
col_dtypes = {column: dtype for column in self._query_compiler.columns}

if not copy:
# If the new types match the old ones, then copying can be avoided
if self._query_compiler._modin_frame.has_materialized_dtypes:
frame_dtypes = self._query_compiler._modin_frame.dtypes
for col in col_dtypes:
if col_dtypes[col] != frame_dtypes[col]:
if isinstance(dtype, dict):
for col in dtype:
if dtype[col] != frame_dtypes[col]:
copy = True
break
else:
if not np.all(frame_dtypes == dtype):
copy = True
break
else:
copy = True

if copy:
new_query_compiler = self._query_compiler.astype(col_dtypes, errors=errors)
new_query_compiler = self._query_compiler.astype(dtype, errors=errors)
return self._create_or_update_from_compiler(new_query_compiler)
return self

Expand Down
4 changes: 3 additions & 1 deletion modin/pandas/test/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,9 @@ def test_astype(data, request):
expected_exception = None
if "float_nan_data" in request.node.callspec.id:
# FIXME: https://github.com/modin-project/modin/issues/7039
expected_exception = False
expected_exception = pandas.errors.IntCastingNaNError(
"Cannot convert non-finite values (NA or inf) to integer"
)
eval_general(
modin_series,
pandas_series,
Expand Down

0 comments on commit 98107c2

Please sign in to comment.