Skip to content

Commit 13fa571

Browse files
committed
[feat] Transplant has_object_columns from reg-cocktail
1 parent 048656e commit 13fa571

File tree

2 files changed

+56
-27
lines changed

2 files changed

+56
-27
lines changed

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import functools
22
from logging import Logger
3-
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast
3+
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union, cast
44

55
import numpy as np
66

@@ -21,6 +21,7 @@
2121
from autoPyTorch.data.utils import (
2222
DatasetCompressionInputType,
2323
DatasetDTypeContainerType,
24+
has_object_columns,
2425
reduce_dataset_size_if_too_large
2526
)
2627
from autoPyTorch.utils.common import ispandas
@@ -105,9 +106,10 @@ def __init__(
105106
logger: Optional[Union[PicklableClientLogger, Logger]] = None,
106107
dataset_compression: Optional[Mapping[str, Any]] = None,
107108
) -> None:
109+
super().__init__(logger)
108110
self._dataset_compression = dataset_compression
109111
self._reduced_dtype: Optional[DatasetDTypeContainerType] = None
110-
super().__init__(logger)
112+
self.all_nan_columns: Optional[Set[str]] = None
111113

112114
@staticmethod
113115
def _comparator(cmp1: str, cmp2: str) -> int:
@@ -132,6 +134,41 @@ def _comparator(cmp1: str, cmp2: str) -> int:
132134
idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
133135
return idx1 - idx2
134136

137+
def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False) -> pd.DataFrame:
    """
    Force columns that were entirely NaN during training to a numeric dtype.

    Args:
        X (pd.DataFrame):
            The data to transform (modified in place column-by-column).
        fit (bool):
            True when fitting on X; False when transforming new data using
            the column set remembered from the fit call.

    Returns:
        pd.DataFrame:
            X with the all-NaN columns coerced to a numeric dtype.

    Raises:
        ValueError:
            If called with fit=False before the validator was fitted.
    """
    if not fit and self.all_nan_columns is None:
        raise ValueError('_fit must be called before calling transform')

    if fit:
        # Detect the columns that contain nothing but NaN in X.
        nan_only_columns = X.columns[X.isna().all()]
    else:
        assert self.all_nan_columns is not None
        nan_only_columns = list(self.all_nan_columns)

    for column in nan_only_columns:
        # Overwrite with NaN (new data may carry values in a column that
        # was all-NaN at fit time) and coerce the column to numeric.
        X[column] = np.nan
        X[column] = pd.to_numeric(X[column])
        if fit and len(self.dtypes):
            # Keep the recorded dtype for this column position in sync.
            self.dtypes[list(X.columns).index(column)] = X[column].dtype

    if has_object_columns(X.dtypes.values):
        X = self.infer_objects(X)

    if fit:
        # TODO: Check how to integrate below
        # self.dtypes = [dt.name for dt in X.dtypes]
        self.all_nan_columns = set(nan_only_columns)

    return X
171+
135172
def _fit(
136173
self,
137174
X: SupportedFeatTypes,
@@ -158,22 +195,7 @@ def _fit(
158195

159196
if ispandas(X) and not issparse(X):
160197
X = cast(pd.DataFrame, X)
161-
# Treat a column with all instances a NaN as numerical
162-
# This will prevent doing encoding to a categorical column made completely
163-
# out of nan values -- which will trigger a fail, as encoding is not supported
164-
# with nan values.
165-
# Columns that are completely made of NaN values are provided to the pipeline
166-
# so that later stages decide how to handle them
167-
if np.any(pd.isnull(X)):
168-
for column in X.columns:
169-
if X[column].isna().all():
170-
X[column] = pd.to_numeric(X[column])
171-
# Also note this change in self.dtypes
172-
if len(self.dtypes) != 0:
173-
self.dtypes[list(X.columns).index(column)] = X[column].dtype
174-
175-
if not X.select_dtypes(include='object').empty:
176-
X = self.infer_objects(X)
198+
X = self._convert_all_nan_columns_to_numeric(X, fit=True)
177199

178200
self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)
179201

@@ -247,14 +269,7 @@ def transform(
247269
X = self.numpy_array_to_pandas(X)
248270

249271
if ispandas(X) and not issparse(X):
250-
if np.any(pd.isnull(X)):
251-
for column in X.columns:
252-
if X[column].isna().all():
253-
X[column] = pd.to_numeric(X[column])
254-
255-
# Also remove the object dtype for new data
256-
if not X.select_dtypes(include='object').empty:
257-
X = self.infer_objects(X)
272+
X = self._convert_all_nan_columns_to_numeric(X)
258273

259274
# Check the data here so we catch problems on new test data
260275
self._check_data(X)
@@ -369,7 +384,7 @@ def _check_data(
369384
X = cast(pd.DataFrame, X)
370385

371386
# Handle objects if possible
372-
if not X.select_dtypes(include='object').empty:
387+
if has_object_columns(X.dtypes.values):
373388
X = self.infer_objects(X)
374389

375390
# Define the column to be encoded here as the feature validator is fitted once

autoPyTorch/data/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@
3939
}
4040

4141

42+
def has_object_columns(feature_types: pd.Series) -> bool:
    """
    Check whether a collection of DataFrame dtypes contains an object column.

    Args:
        feature_types (pd.Series):
            The dtypes of a DataFrame (e.g. ``df.dtypes``); an ndarray of
            dtypes (``df.dtypes.values``) is accepted as well.

    Returns:
        bool:
            True if at least one dtype is ``object``, False otherwise.
    """
    # `in` on a pandas Series tests *index* membership, not values, so a
    # Series argument would silently check the column names instead of the
    # dtypes. Normalize to an array of dtype values first; for ndarray
    # inputs (what the current callers pass) this is a no-op.
    return np.dtype('O') in np.asarray(feature_types)
54+
55+
4256
def get_dataset_compression_mapping(
4357
memory_limit: int,
4458
dataset_compression: Union[bool, Mapping[str, Any]]

0 commit comments

Comments
 (0)