Skip to content

Commit 9a7ce79

Browse files
committed
[fix] Fix mypy issues
1 parent bd2a73b commit 9a7ce79

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from autoPyTorch.data.utils import (
2222
DatasetCompressionInputType,
2323
DatasetDTypeContainerType,
24+
ispandas,
2425
reduce_dataset_size_if_too_large
2526
)
2627
from autoPyTorch.utils.common import autoPyTorchEnum
@@ -211,7 +212,7 @@ def _fit(self, X: SupportedFeatTypes) -> BaseEstimator:
211212
if isinstance(X, np.ndarray):
212213
X = self.numpy_to_pandas(X)
213214

214-
if hasattr(X, "iloc") and not issparse(X):
215+
if ispandas(X) and not issparse(X):
215216
X = cast(pd.DataFrame, X)
216217
X = self._convert_all_nan_columns_to_numeric(X, fit=True)
217218
self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
@@ -328,7 +329,7 @@ def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.Dat
328329

329330
# If a list was provided, it will be converted to pandas
330331
X = self.list_to_pandas(X) if isinstance(X, list) else self.numpy_to_pandas(X)
331-
if hasattr(X, "iloc") and not issparse(X):
332+
if ispandas(X) and not issparse(X):
332333
X = self._convert_all_nan_columns_to_numeric(X)
333334
if len(self.categorical_columns) > 0:
334335
X = self._adapt_categorical_columns_to_train_data(X)
@@ -375,7 +376,7 @@ def _compress_dataset(self, X: DatasetCompressionInputType) -> DatasetCompressio
375376
DatasetCompressionInputType:
376377
Compressed dataset.
377378
"""
378-
is_dataframe = hasattr(X, 'iloc')
379+
is_dataframe = ispandas(X)
379380
is_reducible_type = isinstance(X, np.ndarray) or issparse(X) or is_dataframe
380381
if not is_reducible_type or self._dataset_compression is None:
381382
return X
@@ -431,17 +432,16 @@ def _check_data(self, X: SupportedFeatTypes) -> None:
431432
f"but got type {str(type(X))} in the current features. This change might cause problems"
432433
)
433434

434-
# Do not support category/string numpy data. Only numbers
435-
if hasattr(X, "dtype") and not np.issubdtype(X.dtype.type, np.number): # type: ignore[union-attr]
435+
if ispandas(X): # For pandas, no support of nan in categorical cols
436+
self._check_dataframe(X)
437+
438+
# For ndarray, no support of category/string
439+
if isinstance(X, np.ndarray) and not np.issubdtype(X.dtype.type, np.number):
440+
dt = X.dtype.type
436441
raise ValueError(
437-
"AutoPyTorch does not support numpy.ndarray with non-numerical dtype, "
438-
f"but got {X.dtype.type}" # type: ignore[union-attr]
442+
f"AutoPyTorch does not support numpy.ndarray with non-numerical dtype, but got {dt}"
439443
)
440444

441-
# Then for Pandas, we do not support Nan in categorical columns
442-
if hasattr(X, "iloc"):
443-
self._check_dataframe(X)
444-
445445
def _get_columns_to_encode(
446446
self,
447447
X: pd.DataFrame,

autoPyTorch/data/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@
3737
}
3838

3939

40+
def ispandas(X: Any) -> bool:
41+
""" Whether X is pandas.DataFrame or pandas.Series """
42+
return hasattr(X, "iloc")
43+
44+
4045
def get_dataset_compression_mapping(
4146
memory_limit: int,
4247
dataset_compression: Union[bool, Mapping[str, Any]]

0 commit comments

Comments
 (0)