Skip to content

Commit aa343f6

Browse files
committed
fix tests
1 parent e6c37a0 commit aa343f6

File tree

5 files changed

+36
-414
lines changed

5 files changed

+36
-414
lines changed

autoPyTorch/api/tabular_classification.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -384,13 +384,6 @@ def search(
384384
dataset_name=dataset_name
385385
)
386386

387-
if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
388-
raise ValueError(
389-
'Hyperparameter optimization requires a validation split. '
390-
'Expected `self.resampling_strategy` to be either '
391-
'(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
392-
)
393-
394387
return self._search(
395388
dataset=self.dataset,
396389
optimize_metric=optimize_metric,

autoPyTorch/api/tabular_regression.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -384,13 +384,6 @@ def search(
384384
dataset_name=dataset_name
385385
)
386386

387-
if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
388-
raise ValueError(
389-
'Hyperparameter optimization requires a validation split. '
390-
'Expected `self.resampling_strategy` to be either '
391-
'(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
392-
)
393-
394387
return self._search(
395388
dataset=self.dataset,
396389
optimize_metric=optimize_metric,

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ def _fit(
150150
all_nan_columns = X.columns[X.isna().all()]
151151
for col in all_nan_columns:
152152
X[col] = pd.to_numeric(X[col])
153+
154+
# Handle objects if possible
155+
exist_object_columns = has_object_columns(X.dtypes.values)
156+
if exist_object_columns:
157+
X = self.infer_objects(X)
158+
153159
self.dtypes = [dt.name for dt in X.dtypes] # Also note this change in self.dtypes
154160
self.all_nan_columns = set(all_nan_columns)
155161

@@ -260,20 +266,22 @@ def transform(
260266

261267
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
262268
X = cast(Type[pd.DataFrame], X)
263-
if self.all_nan_columns is not None:
264-
for column in X.columns:
265-
if column in self.all_nan_columns:
266-
if not X[column].isna().all():
267-
X[column] = np.nan
268-
X[column] = pd.to_numeric(X[column])
269-
if len(self.categorical_columns) > 0:
270-
if self.column_transformer is None:
271-
raise AttributeError("Expect column transformer to be built"
272-
"if there are categorical columns")
273-
categorical_columns = self.column_transformer.transformers_[0][-1]
274-
for column in categorical_columns:
275-
if X[column].isna().all():
276-
X[column] = X[column].astype('object')
269+
270+
if self.all_nan_columns is None:
271+
raise ValueError('_fit must be called before calling transform')
272+
273+
for col in list(self.all_nan_columns):
274+
X[col] = np.nan
275+
X[col] = pd.to_numeric(X[col])
276+
277+
if len(self.categorical_columns) > 0:
278+
if self.column_transformer is None:
279+
raise AttributeError("Expect column transformer to be built"
280+
"if there are categorical columns")
281+
categorical_columns = self.column_transformer.transformers_[0][-1]
282+
for column in categorical_columns:
283+
if X[column].isna().all():
284+
X[column] = X[column].astype('object')
277285

278286
# Check the data here so we catch problems on new test data
279287
self._check_data(X)
@@ -366,10 +374,10 @@ def _check_data(
366374
self.column_order = column_order
367375

368376
dtypes = [dtype.name for dtype in X.dtypes]
369-
377+
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
370378
if len(self.dtypes) == 0:
371379
self.dtypes = dtypes
372-
elif self.dtypes != dtypes:
380+
elif not self._is_datasets_consistent(diff_cols, X):
373381
raise ValueError("The dtype of the features must not be changed after fit(), but"
374382
" the dtypes of some columns are different between training ({}) and"
375383
" test ({}) datasets.".format(self.dtypes, dtypes))
@@ -517,11 +525,17 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
517525
self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
518526
pass
519527
else:
520-
# Calling for the first time to infer the categories
521-
X = X.infer_objects()
522-
for column, data_type in zip(X.columns, X.dtypes):
523-
if not is_numeric_dtype(data_type):
524-
X[column] = X[column].astype('category')
528+
if len(self.dtypes) != 0:
529+
# when train data has no object dtype, but test does
530+
# we prioritise the datatype given in training data
531+
for column, data_type in zip(X.columns, self.dtypes):
532+
X[column] = X[column].astype(data_type)
533+
else:
534+
# Calling for the first time to infer the categories
535+
X = X.infer_objects()
536+
for column, data_type in zip(X.columns, X.dtypes):
537+
if not is_numeric_dtype(data_type):
538+
X[column] = X[column].astype('category')
525539

526540
# only numerical attributes and categories
527541
self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)}

0 commit comments

Comments
 (0)