Skip to content

Commit ffc1620

Browse files
authored
handle nans in categorical columns (#118)
* handle nans in categorical columns * Fixed error in self dtypes * Addressed comments from francisco * Forgot to commit * Fix flake
1 parent 5adc607 commit ffc1620

File tree

2 files changed

+51
-41
lines changed

2 files changed

+51
-41
lines changed

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import sklearn.utils
1212
from sklearn import preprocessing
1313
from sklearn.base import BaseEstimator
14-
from sklearn.compose import make_column_transformer
14+
from sklearn.compose import ColumnTransformer
1515
from sklearn.exceptions import NotFittedError
1616

1717
from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES
@@ -53,16 +53,34 @@ def _fit(
5353
for column in X.columns:
5454
if X[column].isna().all():
5555
X[column] = pd.to_numeric(X[column])
56+
# Also note this change in self.dtypes
57+
if len(self.dtypes) != 0:
58+
self.dtypes[list(X.columns).index(column)] = X[column].dtype
5659

5760
self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
5861

5962
if len(self.enc_columns) > 0:
60-
61-
self.encoder = make_column_transformer(
62-
(preprocessing.OrdinalEncoder(
63-
handle_unknown='use_encoded_value',
64-
unknown_value=-1,
65-
), self.enc_columns),
63+
# impute missing values before encoding,
64+
# remove once sklearn natively supports
65+
# it in ordinal encoding. Sklearn issue:
66+
# "https://github.com/scikit-learn/scikit-learn/issues/17123)"
67+
for column in self.enc_columns:
68+
if X[column].isna().any():
69+
missing_value: typing.Union[int, str] = -1
70+
# make sure for a string column we give
71+
# string missing value else we give numeric
72+
if type(X[column][0]) == str:
73+
missing_value = str(missing_value)
74+
X[column] = X[column].cat.add_categories([missing_value])
75+
X[column] = X[column].fillna(missing_value)
76+
77+
self.encoder = ColumnTransformer(
78+
[
79+
("encoder",
80+
preprocessing.OrdinalEncoder(
81+
handle_unknown='use_encoded_value',
82+
unknown_value=-1,
83+
), self.enc_columns)],
6684
remainder="passthrough"
6785
)
6886

@@ -85,6 +103,7 @@ def comparator(cmp1: str, cmp2: str) -> int:
85103
return 1
86104
else:
87105
raise ValueError((cmp1, cmp2))
106+
88107
self.feat_type = sorted(
89108
self.feat_type,
90109
key=functools.cmp_to_key(comparator)
@@ -182,9 +201,8 @@ def _check_data(
182201
if not isinstance(X, (np.ndarray, pd.DataFrame)) and not scipy.sparse.issparse(X):
183202
raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames,"
184203
" scipy sparse and Python Lists, yet, the provided input is"
185-
" of type {}".format(
186-
type(X)
187-
))
204+
" of type {}".format(type(X))
205+
)
188206

189207
if self.data_type is None:
190208
self.data_type = type(X)
@@ -217,39 +235,25 @@ def _check_data(
217235
# per estimator
218236
enc_columns, _ = self._get_columns_to_encode(X)
219237

220-
if len(enc_columns) > 0:
221-
if np.any(pd.isnull(
222-
X[enc_columns].dropna( # type: ignore[call-overload]
223-
axis='columns', how='all')
224-
)):
225-
# Ignore all NaN columns, and if still a NaN
226-
# Error out
227-
raise ValueError("Categorical features in a dataframe cannot contain "
228-
"missing/NaN values. The OrdinalEncoder used by "
229-
"AutoPyTorch cannot handle this yet (due to a "
230-
"limitation on scikit-learn being addressed via: "
231-
"https://github.com/scikit-learn/scikit-learn/issues/17123)"
232-
)
233238
column_order = [column for column in X.columns]
234239
if len(self.column_order) > 0:
235240
if self.column_order != column_order:
236241
raise ValueError("Changing the column order of the features after fit() is "
237242
"not supported. Fit() method was called with "
238-
"{} whereas the new features have {} as type".format(
239-
self.column_order,
240-
column_order,
241-
))
243+
"{} whereas the new features have {} as type".format(self.column_order,
244+
column_order,)
245+
)
242246
else:
243247
self.column_order = column_order
244248
dtypes = [dtype.name for dtype in X.dtypes]
245249
if len(self.dtypes) > 0:
246250
if self.dtypes != dtypes:
247251
raise ValueError("Changing the dtype of the features after fit() is "
248252
"not supported. Fit() method was called with "
249-
"{} whereas the new features have {} as type".format(
250-
self.dtypes,
251-
dtypes,
252-
))
253+
"{} whereas the new features have {} as type".format(self.dtypes,
254+
dtypes,
255+
)
256+
)
253257
else:
254258
self.dtypes = dtypes
255259

@@ -294,7 +298,8 @@ def _get_columns_to_encode(
294298
"pandas.Series.astype ."
295299
"If working with string objects, the following "
296300
"tutorial illustrates how to work with text data: "
297-
"https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format( # noqa: E501
301+
"https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format(
302+
# noqa: E501
298303
column,
299304
)
300305
)
@@ -349,15 +354,13 @@ def list_to_dataframe(
349354
# If a list was provided, it will be converted to pandas
350355
X_train = pd.DataFrame(data=X_train).infer_objects()
351356
self.logger.warning("The provided feature types to AutoPyTorch are of type list."
352-
"Features have been interpreted as: {}".format(
353-
[(col, t) for col, t in zip(X_train.columns, X_train.dtypes)]
354-
))
357+
"Features have been interpreted as: {}".format([(col, t) for col, t in
358+
zip(X_train.columns, X_train.dtypes)]))
355359
if X_test is not None:
356360
if not isinstance(X_test, list):
357361
self.logger.warning("Train features are a list while the provided test data"
358-
"is {}. X_test will be casted as DataFrame.".format(
359-
type(X_test)
360-
))
362+
"is {}. X_test will be casted as DataFrame.".format(type(X_test))
363+
)
361364
X_test = pd.DataFrame(data=X_test).infer_objects()
362365
return X_train, X_test
363366

test/test_data/test_feature_validator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,17 @@ def test_featurevalidator_unsupported_numpy(input_data_featuretest):
231231
),
232232
indirect=True
233233
)
234-
def test_featurevalidator_unsupported_pandas(input_data_featuretest):
234+
def test_featurevalidator_categorical_nan(input_data_featuretest):
235235
validator = TabularFeatureValidator()
236-
with pytest.raises(ValueError, match=r"Categorical features in a dataframe.*missing/NaN"):
237-
validator.fit(input_data_featuretest)
236+
validator.fit(input_data_featuretest)
237+
transformed_X = validator.transform(input_data_featuretest)
238+
assert any(pd.isna(input_data_featuretest))
239+
assert any((-1 in categories) or ('-1' in categories) for categories in
240+
validator.encoder.named_transformers_['encoder'].categories_)
241+
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
242+
assert np.issubdtype(transformed_X.dtype, np.number)
243+
assert validator._is_fitted
244+
assert isinstance(transformed_X, np.ndarray)
238245

239246

240247
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)