Skip to content

Commit a3d40ac

Browse files
authored
Merge pull request #126 from franchuterivera/refactor_development_benchmarkproblems
Fixes to address automlbenchmark problems
2 parents 6da24a8 + c95fbf3 commit a3d40ac

File tree

5 files changed

+181
-25
lines changed

5 files changed

+181
-25
lines changed

autoPyTorch/api/base_task.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,8 +858,11 @@ def _search(
858858
saveable_trajectory = \
859859
[list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:])
860860
for entry in self.trajectory]
861-
with open(trajectory_filename, 'w') as fh:
862-
json.dump(saveable_trajectory, fh)
861+
try:
862+
with open(trajectory_filename, 'w') as fh:
863+
json.dump(saveable_trajectory, fh)
864+
except Exception as e:
865+
self._logger.warning(f"Cannot save {trajectory_filename} due to {e}...")
863866
except Exception as e:
864867
self._logger.exception(str(e))
865868
raise

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 125 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,13 @@ def _fit(
5757
if len(self.dtypes) != 0:
5858
self.dtypes[list(X.columns).index(column)] = X[column].dtype
5959

60+
if not X.select_dtypes(include='object').empty:
61+
X = self.infer_objects(X)
62+
6063
self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
6164

6265
if len(self.enc_columns) > 0:
63-
# impute missing values before encoding,
64-
# remove once sklearn natively supports
65-
# it in ordinal encoding. Sklearn issue:
66-
# "https://github.com/scikit-learn/scikit-learn/issues/17123)"
67-
for column in self.enc_columns:
68-
if X[column].isna().any():
69-
missing_value: typing.Union[int, str] = -1
70-
# make sure for a string column we give
71-
# string missing value else we give numeric
72-
if type(X[column][0]) == str:
73-
missing_value = str(missing_value)
74-
X[column] = X[column].cat.add_categories([missing_value])
75-
X[column] = X[column].fillna(missing_value)
66+
X = self.impute_nan_in_categories(X)
7667

7768
self.encoder = ColumnTransformer(
7869
[
@@ -160,6 +151,10 @@ def transform(
160151
if X[column].isna().all():
161152
X[column] = pd.to_numeric(X[column])
162153

154+
# Also remove the object dtype for new data
155+
if not X.select_dtypes(include='object').empty:
156+
X = self.infer_objects(X)
157+
163158
# Check the data here so we catch problems on new test data
164159
self._check_data(X)
165160

@@ -172,18 +167,32 @@ def transform(
172167
for column in X.columns:
173168
if X[column].isna().all():
174169
X[column] = pd.to_numeric(X[column])
170+
171+
# We also need to fillna on the transformation
172+
# in case test data is provided
173+
X = self.impute_nan_in_categories(X)
174+
175175
X = self.encoder.transform(X)
176176

177177
# Sparse related transformations
178178
# Not all sparse format support index sorting
179179
if scipy.sparse.issparse(X) and hasattr(X, 'sort_indices'):
180180
X.sort_indices()
181181

182-
return sklearn.utils.check_array(
183-
X,
184-
force_all_finite=False,
185-
accept_sparse='csr'
186-
)
182+
try:
183+
X = sklearn.utils.check_array(
184+
X,
185+
force_all_finite=False,
186+
accept_sparse='csr'
187+
)
188+
except Exception as e:
189+
self.logger.exception(f"Conversion failed for input {X.dtypes} {X}"
190+
"This means AutoPyTorch was not able to properly "
191+
"Extract the dtypes of the provided input features. "
192+
"Please try to manually cast it to a supported "
193+
"numerical or categorical values.")
194+
raise e
195+
return X
187196

188197
def _check_data(
189198
self,
@@ -231,6 +240,10 @@ def _check_data(
231240
# If entered here, we have a pandas dataframe
232241
X = typing.cast(pd.DataFrame, X)
233242

243+
# Handle objects if possible
244+
if not X.select_dtypes(include='object').empty:
245+
X = self.infer_objects(X)
246+
234247
# Define the column to be encoded here as the feature validator is fitted once
235248
# per estimator
236249
enc_columns, _ = self._get_columns_to_encode(X)
@@ -245,6 +258,7 @@ def _check_data(
245258
)
246259
else:
247260
self.column_order = column_order
261+
248262
dtypes = [dtype.name for dtype in X.dtypes]
249263
if len(self.dtypes) > 0:
250264
if self.dtypes != dtypes:
@@ -379,3 +393,96 @@ def numpy_array_to_pandas(
379393
pd.DataFrame
380394
"""
381395
return pd.DataFrame(X).infer_objects().convert_dtypes()
396+
397+
def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
398+
"""
399+
In case the input contains object columns, their type is inferred if possible
400+
401+
This has to be done once, so the test and train data are treated equally
402+
403+
Arguments:
404+
X (pd.DataFrame):
405+
data to be interpreted.
406+
407+
Returns:
408+
pd.DataFrame
409+
"""
410+
if hasattr(self, 'object_dtype_mapping'):
411+
# Mypy does not process the has attr. This dict is defined below
412+
for key, dtype in self.object_dtype_mapping.items(): # type: ignore[has-type]
413+
if 'int' in dtype.name:
414+
# In the case train data was interpreted as int
415+
# and test data was interpreted as float, because of 0.0
416+
# for example, honor training data
417+
X[key] = X[key].applymap(np.int64)
418+
else:
419+
try:
420+
X[key] = X[key].astype(dtype.name)
421+
except Exception as e:
422+
# Try inference if possible
423+
self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
424+
pass
425+
else:
426+
X = X.infer_objects()
427+
for column in X.columns:
428+
if not is_numeric_dtype(X[column]):
429+
X[column] = X[column].astype('category')
430+
self.object_dtype_mapping = {column: X[column].dtype for column in X.columns}
431+
self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}")
432+
return X
433+
434+
def impute_nan_in_categories(self, X: pd.DataFrame) -> pd.DataFrame:
435+
"""
436+
impute missing values before encoding,
437+
remove once sklearn natively supports
438+
it in ordinal encoding. Sklearn issue:
439+
"https://github.com/scikit-learn/scikit-learn/issues/17123)"
440+
441+
Arguments:
442+
X (pd.DataFrame):
443+
data to be interpreted.
444+
445+
Returns:
446+
pd.DataFrame
447+
"""
448+
449+
# To be on the safe side, map always to the same missing
450+
# value per column
451+
if not hasattr(self, 'dict_nancol_to_missing'):
452+
self.dict_missing_value_per_col: typing.Dict[str, typing.Any] = {}
453+
454+
# First make sure that we do not alter the type of the column which cause:
455+
# TypeError: '<' not supported between instances of 'int' and 'str'
456+
# in the encoding
457+
for column in self.enc_columns:
458+
if X[column].isna().any():
459+
if column not in self.dict_missing_value_per_col:
460+
try:
461+
float(X[column].dropna().values[0])
462+
can_cast_as_number = True
463+
except Exception:
464+
can_cast_as_number = False
465+
if can_cast_as_number:
466+
# In this case, we expect to have a number as category
467+
# it might be string, but its value represent a number
468+
missing_value: typing.Union[str, int] = '-1' if isinstance(X[column].dropna().values[0],
469+
str) else -1
470+
else:
471+
missing_value = 'Missing!'
472+
473+
# Make sure this missing value is not seen before
474+
# Do this check for categorical columns
475+
# else modify the value
476+
if hasattr(X[column], 'cat'):
477+
while missing_value in X[column].cat.categories:
478+
if isinstance(missing_value, str):
479+
missing_value += '0'
480+
else:
481+
missing_value += missing_value
482+
self.dict_missing_value_per_col[column] = missing_value
483+
484+
# Convert the frame in place
485+
X[column].cat.add_categories([self.dict_missing_value_per_col[column]],
486+
inplace=True)
487+
X.fillna({column: self.dict_missing_value_per_col[column]}, inplace=True)
488+
return X

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
134134
def get_additional_run_info(self) -> None: # pylint: disable=R0201
135135
return None
136136

137+
def get_pipeline_representation(self) -> Dict[str, str]:
138+
return {
139+
'Preprocessing': 'None',
140+
'Estimator': 'Dummy',
141+
}
142+
137143
@staticmethod
138144
def get_default_pipeline_options() -> Dict[str, Any]:
139145
return {'budget_type': 'epochs',

test/test_api/test_api.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import pickle
33
import sys
4+
import unittest
45

56
import numpy as np
67

@@ -21,6 +22,7 @@
2122
CrossValTypes,
2223
HoldoutValTypes,
2324
)
25+
from autoPyTorch.optimizer.smbo import AutoMLSMBO
2426

2527

2628
# Fixtures
@@ -344,3 +346,45 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
344346
with open(dump_file, 'rb') as f:
345347
restored_estimator = pickle.load(f)
346348
restored_estimator.predict(X_test)
349+
350+
351+
@pytest.mark.parametrize('openml_id', (
352+
1590, # Adult to test NaN in categorical columns
353+
))
354+
def test_tabular_input_support(openml_id, backend):
355+
"""
356+
Make sure we can process inputs with NaN in categorical and Object columns
357+
when the later is possible
358+
"""
359+
360+
# Get the data and check that contents of data-manager make sense
361+
X, y = sklearn.datasets.fetch_openml(
362+
data_id=int(openml_id),
363+
return_X_y=True, as_frame=True
364+
)
365+
366+
# Make sure we are robust against objects
367+
X[X.columns[0]] = X[X.columns[0]].astype(object)
368+
369+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
370+
X, y, random_state=1)
371+
# Search for a good configuration
372+
estimator = TabularClassificationTask(
373+
backend=backend,
374+
resampling_strategy=HoldoutValTypes.holdout_validation,
375+
ensemble_size=0,
376+
)
377+
378+
estimator._do_dummy_prediction = unittest.mock.MagicMock()
379+
380+
with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock:
381+
AutoMLSMBOMock.return_value = ({}, {}, 'epochs')
382+
estimator.search(
383+
X_train=X_train, y_train=y_train,
384+
X_test=X_test, y_test=y_test,
385+
optimize_metric='accuracy',
386+
total_walltime_limit=150,
387+
func_eval_time_limit=50,
388+
traditional_per_total_budget=0,
389+
load_models=False,
390+
)

test/test_data/test_feature_validator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def test_featurevalidator_categorical_nan(input_data_featuretest):
236236
validator.fit(input_data_featuretest)
237237
transformed_X = validator.transform(input_data_featuretest)
238238
assert any(pd.isna(input_data_featuretest))
239-
assert any((-1 in categories) or ('-1' in categories) for categories in
239+
assert any((-1 in categories) or ('-1' in categories) or ('Missing!' in categories) for categories in
240240
validator.encoder.named_transformers_['encoder'].categories_)
241241
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
242242
assert np.issubdtype(transformed_X.dtype, np.number)
@@ -328,10 +328,6 @@ def test_features_unsupported_calls_are_raised():
328328
validator.fit(
329329
pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})
330330
)
331-
with pytest.raises(ValueError, match="has invalid type object"):
332-
validator.fit(
333-
pd.DataFrame({'string': [TabularFeatureValidator()]})
334-
)
335331
with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"):
336332
validator.fit({'input1': 1, 'input2': 2})
337333
with pytest.raises(ValueError, match=r"has unsupported dtype string"):

0 commit comments

Comments
 (0)