Skip to content

Commit c3e0fa0

Browse files
committed
[refactor] Separate convert all nan columns to numeric
1 parent 080fe95 commit c3e0fa0

File tree

2 files changed

+56
-44
lines changed

2 files changed

+56
-44
lines changed

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,42 @@ def _comparator(cmp1: str, cmp2: str) -> int:
154154
idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
155155
return idx1 - idx2
156156

157-
def _fit(
158-
self,
159-
X: SupportedFeatTypes,
160-
) -> BaseEstimator:
157+
def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False) -> pd.DataFrame:
158+
"""
159+
Convert columns whose values were all nan in the training dataset to numeric.
160+
161+
Args:
162+
X (pd.DataFrame):
163+
The data to transform.
164+
fit (bool):
165+
Whether this call is the fit to X or the transform using pre-fitted transformer.
166+
"""
167+
if not fit and self.all_nan_columns is None:
168+
raise ValueError('_fit must be called before calling transform')
169+
170+
if fit:
171+
all_nan_columns = X.columns[X.isna().all()]
172+
else:
173+
assert self.all_nan_columns is not None
174+
all_nan_columns = list(self.all_nan_columns)
175+
176+
for col in all_nan_columns:
177+
X[col] = np.nan
178+
X[col] = pd.to_numeric(X[col])
179+
if len(self.dtypes):
180+
self.dtypes[list(X.columns).index(col)] = X[col].dtype
181+
182+
if has_object_columns(X.dtypes.values):
183+
X = self.infer_objects(X)
184+
185+
if fit:
186+
# TODO: Check how to integrate below
187+
# self.dtypes = [dt.name for dt in X.dtypes]
188+
self.all_nan_columns = set(all_nan_columns)
189+
190+
return X
191+
192+
def _fit(self, X: SupportedFeatTypes) -> BaseEstimator:
161193
"""
162194
In case input data is a pandas DataFrame, this utility encodes the user provided
163195
features (from categorical for example) to a numerical value that further stages
@@ -180,23 +212,7 @@ def _fit(
180212

181213
if ispandas(X) and not issparse(X):
182214
X = cast(pd.DataFrame, X)
183-
# Treat a column with all instances a NaN as numerical
184-
# This will prevent doing encoding to a categorical column made completely
185-
# out of nan values -- which will trigger a fail, as encoding is not supported
186-
# with nan values.
187-
# Columns that are completely made of NaN values are provided to the pipeline
188-
# so that later stages decide how to handle them
189-
if np.any(pd.isnull(X)):
190-
for column in X.columns:
191-
if X[column].isna().all():
192-
X[column] = pd.to_numeric(X[column])
193-
# Also note this change in self.dtypes
194-
if len(self.dtypes) != 0:
195-
self.dtypes[list(X.columns).index(column)] = X[column].dtype
196-
197-
if not X.select_dtypes(include='object').empty:
198-
X = self.infer_objects(X)
199-
215+
X = self._convert_all_nan_columns_to_numeric(X, fit=True)
200216
self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
201217

202218
assert self.feat_type is not None
@@ -241,10 +257,7 @@ def _fit(
241257
self.num_features = np.shape(X)[1]
242258
return self
243259

244-
def transform(
245-
self,
246-
X: SupportedFeatTypes,
247-
) -> Union[np.ndarray, spmatrix, pd.DataFrame]:
260+
def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.DataFrame]:
248261
"""
249262
Validates and fit a categorical encoder (if needed) to the features.
250263
The supported data types are List, numpy arrays and pandas DataFrames.
@@ -264,19 +277,11 @@ def transform(
264277
# If a list was provided, it will be converted to pandas
265278
if isinstance(X, list):
266279
X = self.list_to_pandas(X)
267-
268-
if isinstance(X, np.ndarray):
280+
elif isinstance(X, np.ndarray):
269281
X = self.numpy_to_pandas(X)
270282

271283
if ispandas(X) and not issparse(X):
272-
if np.any(pd.isnull(X)):
273-
for column in X.columns:
274-
if X[column].isna().all():
275-
X[column] = pd.to_numeric(X[column])
276-
277-
# Also remove the object dtype for new data
278-
if not X.select_dtypes(include='object').empty:
279-
X = self.infer_objects(X)
284+
X = self._convert_all_nan_columns_to_numeric(X)
280285

281286
# Check the data here so we catch problems on new test data
282287
self._check_data(X)
@@ -344,10 +349,7 @@ def _compress_dataset(self, X: DatasetCompressionInputType) -> DatasetCompressio
344349
self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype
345350
return X
346351

347-
def _check_data(
348-
self,
349-
X: SupportedFeatTypes,
350-
) -> None:
352+
def _check_data(self, X: SupportedFeatTypes) -> None:
351353
"""
352354
Feature dimensionality and data type checks
353355
@@ -481,10 +483,7 @@ def list_to_pandas(self, X: SupportedFeatTypes) -> pd.DataFrame:
481483
)
482484
return X
483485

484-
def numpy_to_pandas(
485-
self,
486-
X: np.ndarray,
487-
) -> pd.DataFrame:
486+
def numpy_to_pandas(self, X: np.ndarray) -> pd.DataFrame:
488487
"""
489488
Converts a numpy array to pandas for type inference
490489
@@ -533,3 +532,17 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
533532
self.object_dtype_mapping = {column: X[column].dtype for column in X.columns}
534533
self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}")
535534
return X
535+
536+
537+
def has_object_columns(feature_types: pd.Series) -> bool:
538+
"""
539+
Indicate whether on a Series of dtypes for a Pandas DataFrame
540+
there exists one or more object columns.
541+
Args:
542+
feature_types (pd.Series): The feature types for a DataFrame.
543+
Returns:
544+
bool:
545+
True if the DataFrame dtypes contain an object column, False
546+
otherwise.
547+
"""
548+
return np.dtype('O') in feature_types

test/test_data/test_feature_validator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,6 @@ def test_features_unsupported_calls_are_raised():
328328
expected
329329
"""
330330
validator = TabularFeatureValidator()
331-
#with pytest.raises(TypeError, match=r"invalid type `time and/or date datatype`."):
332331
with pytest.raises(TypeError, match=r"invalid type `time and/or date datatype`."):
333332
validator.fit(pd.DataFrame({'datetime': [pd.Timestamp('20180310')]}))
334333
with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"):

0 commit comments

Comments
 (0)